码迷,mamicode.com
首页 > 其他好文 > 详细

一维线性回归

时间:2019-08-28 01:16:13      阅读:89      评论:0      收藏:0      [点我收藏+]

标签:back   target   targe   ict   mat   info   eval   cuda   com   

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim


# Toy 1-D regression data: 15 (x, y) samples. Each set is stored as a float32
# column vector of shape (15, 1), one feature per row, which is the layout
# nn.Linear(1, 1) consumes. torch.from_numpy shares memory with the array.
x_train = torch.from_numpy(
    np.array(
        [3.3, 4.4, 5.5, 6.71, 6.93, 4.168,
         9.779, 6.182, 7.59, 2.167, 7.042,
         10.791, 5.313, 7.997, 3.1],
        dtype=np.float32,
    ).reshape(-1, 1)
)

y_train = torch.from_numpy(
    np.array(
        [1.7, 2.76, 2.09, 3.19, 1.694, 1.573,
         3.366, 2.596, 2.53, 1.221, 2.827,
         3.465, 1.65, 2.904, 1.3],
        dtype=np.float32,
    ).reshape(-1, 1)
)


class LinearRegression(nn.Module):
    """Single-feature linear model computing ``y = w * x + b``.

    Wraps one ``nn.Linear(1, 1)`` layer, stored as ``self.linear``. Its
    weight has shape (1, 1) and its bias has shape (1,).
    """

    def __init__(self):
        super().__init__()
        # One input feature -> one output feature; weight (1, 1), bias (1,).
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        """Apply the affine map to ``x``, a tensor whose last dimension is 1."""
        return self.linear(x)

# Build the model, then move its parameters to the GPU when CUDA is present.
model = LinearRegression()
if torch.cuda.is_available():
    model = model.cuda()


criterion = nn.MSELoss()  # mean squared error
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# Use whatever device the model's parameters live on, so data and weights
# always agree. This is the CPU when CUDA is unavailable; the old
# unconditional x_train.cuda() crashed in that case.
device = next(model.parameters()).device

# The data never changes between epochs, so copy it to the device once
# instead of on every iteration. inputs is (N, 1) (15 samples here), and
# nn.Linear(1, 1) maps it to an (N, 1) output.
inputs = x_train.to(device)
target = y_train.to(device)

num_epoches = 1000
for epoch in range(num_epoches):
    out = model(inputs)
    loss = criterion(out, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 20 == 0:
        print("Epoch[{}/{}], loss: {:.6f}".format((epoch+1), num_epoches, loss.item()))

model.eval()
# Inference only: no autograd graph is needed to produce the fitted line.
with torch.no_grad():
    predict = model(inputs)
# Matplotlib needs host-side NumPy data.
predict = predict.cpu().numpy()
plt.plot(x_train.numpy(), y_train.numpy(), "ro", label="Original data")
plt.plot(x_train.numpy(), predict, label="Fitting Line")
# Without this call the `label=` values above are never displayed.
plt.legend()
plt.show()

技术图片

 

一维线性回归

标签:back   target   targe   ict   mat   info   eval   cuda   com   

原文地址:https://www.cnblogs.com/liualexsone/p/11421377.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!