码迷,mamicode.com
首页 > 其他好文 > 详细

hinge loss/支持向量损失的理解

时间:2019-12-20 01:06:55      阅读:88      评论:0      收藏:0      [点我收藏+]

标签:nal   tps   tensor   需要   sele   ret   for   pytho   div   

https://blog.csdn.net/AI_focus/article/details/78339234

https://www.cnblogs.com/massquantity/p/8964029.html

pytprch HingeLoss 的实现:

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as TF
import torchvision.utils as vutils
import torch.nn.functional as F

class HingeLoss(nn.Module):
    """
    铰链损失
    SVM hinge loss
    L1 loss = sum(max(0,pred-true+1)) / batch_size
    注意: 此处不包含正则化项, 需要另外计算出来 add
    https://blog.csdn.net/AI_focus/article/details/78339234
    """

    def __init__(self, n_classes, margin=1.):
        super(HingeLoss, self).__init__()
        self.margin = margin
        self.n_classes = n_classes

    def forward(self, y_pred, y_truth):
        # y_pred: [b,n_classes]    = W^t.X
        # y_truth:[b,]
        batch_size = y_truth.size(0)
        mask = torch.eye(self.n_classes, self.n_classes, dtype=torch.bool)[y_truth].cuda()
        y_pred_true = torch.masked_select(y_pred, mask).unsqueeze(dim=-1).cuda()
        loss = torch.max(torch.zeros_like(y_pred).cuda(), y_pred - y_pred_true + self.margin)
        loss = loss.masked_fill(mask, 0)
        return torch.sum(loss) / batch_size


if __name__ == ‘__main__‘:
    LossFun = HingeLoss(5)
    y_truth = torch.tensor([0, 1, 2])
    y_pred = torch.randn([3, 5])
    loss = LossFun(y_pred, y_truth)
    print(loss)

  

hinge loss/支持向量损失的理解

标签:nal   tps   tensor   需要   sele   ret   for   pytho   div   

原文地址:https://www.cnblogs.com/dxscode/p/12019794.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!