码迷,mamicode.com
首页 > 其他好文 > 详细

损失函数

时间:2020-02-24 16:42:33      阅读:89      评论:0      收藏:0      [点我收藏+]

标签:结果   有一个   rop   size   函数   ros   一个   注意   batch   

1、torch.nn.CrossEntropyLoss()

loss_func=torch.nn.CrossEntropyLoss()

loss=loss_func(input_data,input_target)

其中input_data的shape一般是(batch_size,output_features),而input_target的shape是(batch_size)

返回的loss是一个张量,但是只有一个数,代表的是计算结果的交叉商损失值

交叉商的计算方法是:

将输入的数据在最后一个维度上做softmax运算

对softmax后的数据取log,注意softmax后所有的数值介于0和1之间,所以log后所有的数值全都是负数

softmax_loged_data=torch.log(torch.nn.Softmax(dim=-1)(input_data))

根据标签对应的数值去softmax_loged_data中索引出相应的数值并且去掉符号,

将这batch_size个数值相加取平均后就是input_data与input_target的交叉商损失值

损失函数

标签:结果   有一个   rop   size   函数   ros   一个   注意   batch   

原文地址:https://www.cnblogs.com/liujianing/p/12357425.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!