Linear Regression with Two Variables

Date: 2019-04-05 12:34:41

Tags: tool   tmp   for   The   limit   implementation   print   surf   pyplot

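import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the "3d" projection on older Matplotlib versions

# Load the delivery data: columns 1 and 2 are the two features,
# the last column is the target.
data = np.genfromtxt("Delivery.csv", delimiter=",")
x_data = data[:, [1, 2]]
y_data = data[:, [-1]]

lr = 0.000001   # learning rate, kept very small (the features are not scaled)
theta0 = 0      # intercept
theta1 = 0      # coefficient of the first feature
theta2 = 0      # coefficient of the second feature
epochs = 50     # number of passes over the whole data set
print(x_data[0])  # inspect the first sample


def gradient_descent_runner(x_data, y_data, lr, theta0, theta1, theta2, epochs):
    """Fit y = theta0 + theta1*x1 + theta2*x2 with batch gradient descent."""
    m = len(x_data)
    for i in range(epochs):
        # Accumulate the gradient sums over all samples before updating.
        tmp_theta0 = 0
        tmp_theta1 = 0
        tmp_theta2 = 0
        for j in range(m):
            # Prediction error for sample j.
            error = theta0 + theta1 * x_data[j][0] + theta2 * x_data[j][1] - y_data[j][0]
            tmp_theta0 += error
            tmp_theta1 += error * x_data[j][0]
            tmp_theta2 += error * x_data[j][1]
        # Update all three parameters simultaneously with the averaged gradient.
        theta0 -= lr * tmp_theta0 / m
        theta1 -= lr * tmp_theta1 / m
        theta2 -= lr * tmp_theta2 / m
    return theta0, theta1, theta2


theta0, theta1, theta2 = gradient_descent_runner(x_data, y_data, lr, theta0, theta1, theta2, epochs)

# Plot the fitted plane on a grid built from the observed feature values.
ax = plt.figure().add_subplot(111, projection="3d")
x0 = x_data[:, 0]
x1 = x_data[:, 1]
x0, x1 = np.meshgrid(x0, x1)
z = theta0 + theta1 * x0 + theta2 * x1
ax.plot_surface(x0, x1, z)
plt.show()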
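
The loop-based function above performs batch gradient descent on a squared-error cost. Assuming the usual mean squared error scaled by 1/(2m) (the code never computes the cost itself, but its accumulated sums are exactly m times the partial derivatives below), the model and update rules are as follows. Here α stands for lr, x_1 and x_2 for the feature columns x_data[:, 0] and x_data[:, 1], and y for y_data[:, 0]:

```latex
\begin{aligned}
h_\theta(x) &= \theta_0 + \theta_1 x_1 + \theta_2 x_2 \\
J(\theta) &= \frac{1}{2m}\sum_{j=1}^{m}\bigl(h_\theta(x^{(j)}) - y^{(j)}\bigr)^2 \\
\frac{\partial J}{\partial \theta_0} &= \frac{1}{m}\sum_{j=1}^{m}\bigl(h_\theta(x^{(j)}) - y^{(j)}\bigr) \\
\frac{\partial J}{\partial \theta_k} &= \frac{1}{m}\sum_{j=1}^{m}\bigl(h_\theta(x^{(j)}) - y^{(j)}\bigr)\,x_k^{(j)}, \quad k = 1, 2 \\
\theta_k &\leftarrow \theta_k - \alpha\,\frac{\partial J}{\partial \theta_k}, \quad k = 0, 1, 2
\end{aligned}
```

Because tmp_theta0, tmp_theta1 and tmp_theta2 are all accumulated before any parameter changes, the three parameters are updated simultaneously, as the rule requires.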

Implementing linear regression with gradient descent.
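
The per-sample double loop can also be written with NumPy matrix operations. The following is a minimal sketch under stated assumptions: the function name gradient_descent_vectorized is hypothetical, and the synthetic data (y = 1 + 2·x1 + 3·x2 plus small noise) only stands in for Delivery.csv, whose contents the post does not show. The learning rate of 0.5 works only because the synthetic features lie in [0, 1]; unscaled real features need a much smaller rate, like the 0.000001 used in the original. As a sanity check, the result is compared with the closed-form least-squares fit from np.linalg.lstsq.

```python
import numpy as np


def gradient_descent_vectorized(X, y, lr, epochs):
    """Hypothetical helper: batch gradient descent for y ≈ theta0 + theta1*x1 + theta2*x2.

    X: (m, 2) feature matrix, y: (m,) targets.
    Returns the array [theta0, theta1, theta2].
    """
    m = X.shape[0]
    # Prepend a column of ones so that theta[0] acts as the intercept.
    Xb = np.hstack([np.ones((m, 1)), X])
    theta = np.zeros(Xb.shape[1])
    for _ in range(epochs):
        error = Xb @ theta - y   # h(x) - y for every sample at once
        grad = Xb.T @ error / m  # the tmp_theta sums divided by m
        theta -= lr * grad       # simultaneous update of all parameters
    return theta


# Assumed synthetic data, NOT the contents of Delivery.csv.
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(200, 2))
y = 1 + 2 * X[:, 0] + 3 * X[:, 1] + rng.normal(0, 0.01, size=200)

theta_gd = gradient_descent_vectorized(X, y, lr=0.5, epochs=5000)

# Closed-form least-squares solution on the same design matrix, for comparison.
Xb = np.hstack([np.ones((len(X), 1)), X])
theta_ls, *_ = np.linalg.lstsq(Xb, y, rcond=None)

print(theta_gd)
print(theta_ls)
```

Both printed vectors should be close to [1, 2, 3]. If gradient descent and lstsq disagree noticeably, the learning rate or the number of epochs is usually the culprit.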

Original post: https://www.cnblogs.com/students/p/10658041.html
