码迷,mamicode.com
首页 > 其他好文 > 详细

Pytorch实现简单的线性回归

时间:2019-11-11 21:40:56      阅读:121      评论:0      收藏:0      [点我收藏+]

标签:cat   numpy   res   expand   oss   item   data   ima   back   

import torch 
from torch.autograd import Variable
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display
torch.manual_seed(2019)

def get_fake_data(batch_size=16):
    x = torch.randn(batch_size,1) * 20
    y = x * 2 + (1 + torch.rand(batch_size,1)) * 3
    return x,y #返回的是二维数组
x_train,y_train = get_fake_data()
plt.scatter(x_train.squeeze().numpy(),y_train.squeeze().numpy()) #x.squeeze()将二维变为一维
w = Variable(torch.rand(1,1),requires_grad=True)
b = Variable(torch.zeros(1,1),requires_grad=True)
lr = 1e-6 #lr不能设置太大,否则会梯度爆炸
for i in range(100000):
    x_train,y_train = get_fake_data()
    x_train,y_train = Variable(x_train),Variable(y_train)
    
    y_pred = x_train.mm(w) + b.expand_as(y_train)
    loss = 0.5 * (y_pred - y_train) ** 2
    loss = loss.sum()
    
    loss.backward()
    
    w.data.sub_(lr * w.grad.data)
    b.data.sub_(lr * b.grad.data)
    
    w.grad.data.zero_()
    b.grad.data.zero_()

    if i % 1000 == 0:
        display.clear_output(wait=True)
        x_test = torch.arange(0,20).view(-1,1).float()
        y_test = x_test.mm(w.data) + b.data.expand_as(x_test)
        plt.plot(x_test.numpy(),y_test.numpy())
        
        x_train,y_train = get_fake_data(batch_size=20)
        plt.scatter(x_train.numpy(),y_train.numpy())
        
        plt.xlim(0,20)
        plt.ylim(0,41)
        plt.show()
        plt.pause(0.5)
        
        
print(w.data.squeeze().item(),b.data.squeeze().item())

最后结果:

技术图片

 

 代码来自于《深度学习框架PyTorch:入门与实践》,环境为PyTorch1.0 + Jupyter

 

 

 

 

 

 

Pytorch实现简单的线性回归

标签:cat   numpy   res   expand   oss   item   data   ima   back   

原文地址:https://www.cnblogs.com/liualex1109/p/11838002.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!