码迷,mamicode.com
首页 > 其他好文 > 详细

02-pytorch

时间:2019-07-07 12:36:16      阅读:91      评论:0      收藏:0      [点我收藏+]

标签:均值   pytho   require   python   tensor   out   平均值   ogr   back   

import torch
from pprint import pprint
from torch.autograd import Variable

变成图纸中的一个节点

tensor = torch.FloatTensor([[1,2],[3,4]])
variable = Variable(tensor,requires_grad=True)
pprint(tensor)
pprint(variable)
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

反向传播误差

t_out = torch.mean(tensor*tensor)  # 求x^2的平均值
v_out = torch.mean(variable*variable)
# v_out = 1/4 *sum(var*var) 
v_out.backward()  # 反向传播误差
# d(v_out)/d(var)    1/4*2*variable = variable/2

print(variable.grad)
tensor([[0.5000, 1.0000],
        [1.5000, 2.0000]])
# variable.data 才是 tensor 的形式
print(variable.data)
tensor([[1., 2.],
        [3., 4.]])

02-pytorch

标签:均值   pytho   require   python   tensor   out   平均值   ogr   back   

原文地址:https://www.cnblogs.com/liu247/p/11145511.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!