标签:cat numpy res expand oss item data ima back
import torch from torch.autograd import Variable %matplotlib inline from matplotlib import pyplot as plt from IPython import display
# Seed torch's global RNG so every run draws the same random numbers.
torch.manual_seed(2019)


def get_fake_data(batch_size=16):
    """Generate a batch of noisy samples from the line y = 2x + offset.

    Args:
        batch_size: number of samples to draw.

    Returns:
        A tuple ``(x, y)`` of float tensors, each of shape ``(batch_size, 1)``
        (two-dimensional, one sample per row). ``x`` is standard-normal noise
        scaled by 20; ``y`` is ``2 * x`` plus a per-sample offset drawn
        uniformly between 3 and 6.
    """
    inputs = torch.randn(batch_size, 1) * 20
    # (1 + U[0, 1)) * 3 -> uniform offset between 3 and 6; serves as both the
    # intercept and the noise of the target line.
    offset = (1 + torch.rand(batch_size, 1)) * 3
    targets = 2 * inputs + offset
    return inputs, targets
# Draw one batch of training data and show it as a scatter plot.
# squeeze() drops the trailing size-1 dimension, turning each (N, 1) tensor
# into a 1-D tensor before it is handed to matplotlib as a NumPy array.
x_train, y_train = get_fake_data()
plt.scatter(*(column.squeeze().numpy() for column in (x_train, y_train)))
# Learnable parameters of the linear model y = x.mm(w) + b.
# Tensors created with requires_grad=True are autograd leaves directly; the
# Variable wrapper is a deprecated no-op in modern PyTorch, so it is not used.
w = torch.rand(1, 1, requires_grad=True)   # weight, drawn uniformly from [0, 1)
b = torch.zeros(1, 1, requires_grad=True)  # bias, starts at 0
# Keep lr small: the inputs from get_fake_data are large (randn * 20) and the
# loss is summed over the batch, so a larger step makes the updates blow up.
lr = 1e-6
# Fit w and b with plain mini-batch gradient descent. A fresh batch is drawn
# on every step; every 1000 steps the current fit is re-plotted.
for i in range(100000):
    x_train, y_train = get_fake_data()

    # Forward pass: the (1, 1) bias is expanded to the batch shape.
    y_pred = x_train.mm(w) + b.expand_as(y_train)
    # Half sum-of-squared-errors over the batch.
    loss = (0.5 * (y_pred - y_train) ** 2).sum()

    loss.backward()

    # Gradient-descent step. Doing the in-place updates under no_grad() keeps
    # them out of the autograd graph; mutating .data instead would bypass
    # autograd's in-place correctness checks.
    with torch.no_grad():
        w -= lr * w.grad
        b -= lr * b.grad
        # backward() accumulates into .grad, so reset it before the next step.
        w.grad.zero_()
        b.grad.zero_()

    if i % 1000 == 0:
        display.clear_output(wait=True)

        # Current fitted line evaluated at x = 0, 1, ..., 19.
        x_test = torch.arange(0, 20, dtype=torch.float32).view(-1, 1)
        y_test = x_test.mm(w.detach()) + b.detach().expand_as(x_test)
        plt.plot(x_test.numpy(), y_test.numpy())

        # A fresh batch of 20 samples drawn for visual comparison.
        x_train, y_train = get_fake_data(batch_size=20)
        plt.scatter(x_train.numpy(), y_train.numpy())

        plt.xlim(0, 20)
        plt.ylim(0, 41)
        plt.show()
        plt.pause(0.5)

# Report the learned weight and bias.
print(w.item(), b.item())
最后结果:
代码来自《深度学习框架PyTorch:入门与实践》,运行环境为 PyTorch 1.0 + Jupyter。
标签:cat numpy res expand oss item data ima back
原文地址:https://www.cnblogs.com/liualex1109/p/11838002.html