标签:near 维度 backward size print 网络 线性回归 oss --
pytorch写神经网络
(1)准备数据集
(2)涉及模型(yheight)
(3)构造损失函数和优化器
(4)训练周期(前馈、反馈、更新)
1 import torch 2 3 #1.准备数据 4 x_data=torch.tensor([[1.0],[2.0],[3.0]]) 5 y_data=torch.tensor([[2.0],[4.0],[6.0]]) 6 #2.构造模型 7 class LinearModel(torch.nn.Module): 8 def __init__(self):#构造函数 9 super(LinearModel,self).__init__() 10 self.linear=torch.nn.Linear(1,1)#构造对象;(1,1)指输入x和输出y的特征维度;第三个参数自动为true,是否要有b 11 12 def forward(self,x): 13 y_pred=self.linear(x)#计算y=ax+b 14 return y_pred 15 16 model=LinearModel()#实例化 17 18 criterion=torch.nn.MSELoss(size_average=False)#计算loss 19 optimizer=torch.optim.SGD(model.parameters(),lr=0.01)#lr:学习率 20 21 for epoch in range(100): 22 y_pred=model(x_data)#前向传播 23 loss=criterion(y_pred,y_data)#计算loss 24 print(epoch,loss.item()) 25 26 optimizer.zero_grad() 27 loss.backward()#反向传播,计算梯度 28 optimizer.step()#更新w,b的值 29 30 print(‘w=‘,model.linear.weight.item()) 31 print("b=",model.linear.bias.item()) 32 33 x_test=torch.tensor([4.0]) 34 y_test=model(x_test) 35 print("y_pred=",y_test)
PyTorch深度学习实践(五)---pytorch实现线性回归
标签:near 维度 backward size print 网络 线性回归 oss --
原文地址:https://www.cnblogs.com/miosk/p/14663144.html