码迷,mamicode.com
首页 > 编程语言 > 详细

最小二乘法 python实现

时间:2015-07-31 23:26:16      阅读:678      评论:0      收藏:0      [点我收藏+]

标签:

 1 import numpy as np
 2 
 3 def SumSquareError(dataset,A):
 4     # 输入目标数据集与假设曲线函数,计算误差平方和
 5     # 数据形式 dataset[i] = [x,y],y = hypfunc(x)
 6     # A: 多项式系数[a0,a1,...,an-1] 
 7     hypresult = [hypfunc(dataset[i,0],A) for i in range(dataset.shape[0])]
 8     sse = np.sum((hypresult - dataset[:,1])**2)
 9     return sse 
10 
11 def hypfunc(x,A):
12     # 输入:x 横坐标数值, A 多项式系数 [a0,a1,...,an-1]
13     # 返回 y = hypfunc(x)
14     return np.sum(A[i]*(x**i) for i in range(len(A)))
15 
16 """
17 最小二乘思路
18 设 假设 yh = a0x^0 + a1x^1 + a2x^2 +...+ akx^k
19 则误差 R2 = sum(y(xi)-yh(xi)) i = 1...n
20       R2 = sum [(yi-(a0x^0 + a1x^1 + a2x^2 +...+ akx^k))]2 ~ 0
21     R2对ai求偏导:并令(共k+1个方程)
22     div(R2,ai) = -2 * sum(yi-(a0x^0 + a1x^1 + a2x^2 +...+ akx^k)) * x^i = 0
23     有如下矩阵 用方程求解
24      [[1 x1 ... x1^k],...,[1 xn ... xn^k]] * [a0,...,ak] = [y1,...,yn]
25 """
26 
27 import random
28 import matplotlib.pyplot as plt
29 
30 if __name__=="__main__":
31     pass
32     # 生成曲线上各个点
33     x = np.arange(-1,1,0.02)
34     y = [((a*a-1)*(a*a-1)*(a*a-1)+0.5)*np.sin(a*2) for a in x]
35     xa = []
36     ya = []
37     # 对曲线上每个点进行随机偏移
38     for i in range(len(x)):
39         d = np.float(random.randint(60,140))/100
40         ya.append(y[i]*d)
41         xa.append(x[i]*d)
42     n = len(xa)     # 数据个数
43     
44     order = 9   # 设定k阶多项式 0 ~ k
45     # 根据数据点构造X,Y的 范德蒙德矩阵
46     matX = np.array([[np.sum([xa[i]**(k2+k1) for i in range(n)]) 
47              for k2 in range(order+1)] for k1 in range(order+1)])
48             
49     matY = np.array([np.sum([(xa[i]**k)*ya[i] for i in range(n)])
50             for k in range(order+1)])
51     print matX.shape,matY.shape
52     
53     A = np.linalg.solve(matX, matY)
54     print A
55 
56     # 画出数据点与拟合曲线
57     plt.figure()
58     # 输出数据点
59     plt.plot(xa,ya,linestyle=‘‘,marker=.) 
60     
61     # 画出拟合后曲线
62     yhyp = [hypfunc(x[i],A) for i in range(n)]
63     plt.plot(x,yhyp,linestyle=-,marker=‘‘) 
64     
65     plt.show()

 

最小二乘法 python实现

标签:

原文地址:http://www.cnblogs.com/hanahimi/p/4693282.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!