标签:bat forward pytorch pac 产生 color 继承 als output
首先是pytorch的机制:
Torch 自称为神经网络界的 Numpy, 因为他能将 torch 产生的 tensor 放在 GPU 中加速运算
最常用的数据形式是tensor , torch中的tensor类似tensorflow中的tensor;
在编写代码的时候,注意把自己的数据转为tensor,然后如果需要梯度计算类似的操作再转为variable;若需要cuda运算,再加上cuda
其次是pytorch用于存储不断变化的量的variable:
variable计算的时候,后台搭建了一个动态计算图(tensorflow大部分时候是静态的),是将所有的计算步骤 (节点) 都连接起来, 最后进行误差反向传递的时候, 一次性将所有 variable 里面的修改幅度 (梯度) 都计算出来, 而 tensor 就没有这个能力。
不能直接获取variable的值,需要通过var.data来转换为tensor形式;
激励函数使得输出结果 y 也有了非线性的特征;cnn中推荐使用relu以及变化版;
class Net(torch.nn.Module): # 继承 torch 的 Module
网络的搭建集成nn.module,一般都会重写forward函数,最后一般通过output = Net(args)
直接获取结果
torch训练过程:
for t in range(100): prediction = net(x) # 喂给 net 训练数据 x, 输出预测值 loss = loss_func(prediction, y) # 计算两者的误差 optimizer.zero_grad() # 清空上一步的残余更新参数值 loss.backward() # 误差反向传播, 计算参数更新值 optimizer.step() # 将参数更新值施加到 net 的 parameters 上
快速搭建网络的sequential序列网络:
net2 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) )
一般保存pytorch网络都采用
torch.save(net1.state_dict(), ‘net_params.pkl‘) # 只保存网络中的参数 (速度快, 占内存少)
提取网络一般采用:
new_net.load_state_dict(torch.load(‘net_params.pkl‘))
举一个pytorch加载数据集的例子:
import torch import torch.utils.data as Data torch.manual_seed(1) # reproducible BATCH_SIZE = 5 # 批训练的数据个数 x = torch.linspace(1, 10, 10) # x data (torch tensor) y = torch.linspace(10, 1, 10) # y data (torch tensor) # 先转换成 torch 能识别的 Dataset torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y) # 把 dataset 放入 DataLoader loader = Data.DataLoader( dataset=torch_dataset, # torch TensorDataset format batch_size=BATCH_SIZE, # mini batch size shuffle=True, # 要不要打乱数据 (打乱比较好) num_workers=2, # 多线程来读数据 ) for epoch in range(3): # 训练所有!整套!数据 3 次 for step, (batch_x, batch_y) in enumerate(loader): # 每一步 loader 释放一小批数据用来学习 # 假设这里就是你训练的地方...
若放在gpu加速的话需要修改的地方有:
数据.cuda()
net.cuda()
在 train 的时候, 将每次的training data 变成 GPU 形式
总之就是参与需要一直变化的运算的一切东西,都在后面加上.cuda()
参考自:https://morvanzhou.github.io/tutorials/machine-learning/torch/3-05-train-on-batch/
莫烦python
标签:bat forward pytorch pac 产生 color 继承 als output
原文地址:https://www.cnblogs.com/ywheunji/p/11027130.html