码迷,mamicode.com
首页 > 其他好文 > 详细

第一章基本操作-线性拟合(GPU版本)

时间:2020-04-07 00:30:56      阅读:88      评论:0      收藏:0      [点我收藏+]

标签:linear   初始化   array   操作   sel   init   实例化   optimize   super   

第一步:构造数据

import numpy as np
import os

x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)

y_values = [i * 2 + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)

第二步: 使用class LinearRegressionModel 

class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        out = self.linear(x)
        return out

第三步: 实例化模型,初始化epochs, 学习率,定义SGD优化函数,以及定义mse优化损失函数,使用model.to(device) 将模型的参数更新放在GPU上 

input_dim = 1
output_dim = 1

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = LinearRegressionModel(input_dim, output_dim)
model.to(device) epochs = 1000 learning_rate = 0.01 optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate) criterion = nn.MSELoss()

第四步: 如果模型存在就使用model.load_state_dict(torch.load("model.pkl")) 加载模型 参数,进行模型的参数优化,每50次,使用torch.save(model.state_dict)保存模型 ,使用to(device) 将训练样本和测试样本放在GPU上 

if os.path.exists("model.pkl"):
    model.load_state_dict(torch.load("model.pkl"))

for epoch in range(epochs):

    inputs = torch.from_numpy(x_train).to(device)
    labels = torch.from_numpy(y_train).to(device)

    # 梯度每次清零
    optimizer.zero_grad()

    # 前向传播
    outputs = model(inputs)

    # 计算损失值
    loss = criterion(outputs, labels)

    #反向传播
    loss.backward()

    #更新权重参数
    optimizer.step()

    if epoch % 50 == 0:
        print("epoch:{},loss:{}".format(epoch, loss.item()))
        torch.save(model.state_dict(), "model.pkl")

 

第一章基本操作-线性拟合(GPU版本)

标签:linear   初始化   array   操作   sel   init   实例化   optimize   super   

原文地址:https://www.cnblogs.com/my-love-is-python/p/12650342.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!