码迷,mamicode.com
首页 > Web开发 > 详细

『MXNet』第八弹_物体检测之SSD

时间:2018-05-30 01:23:22      阅读:930      评论:0      收藏:0      [点我收藏+]

标签:坐标   clu   net   war   loss   div   one   multi   class   

预、API介绍

mxnet.metric

from mxnet import metric

cls_metric = metric.Accuracy()
box_metric = metric.MAE() 

cls_metric.update([cls_target], [class_preds.transpose((0,2,1))])
box_metric.update([box_target], [box_preds * box_mask])
cls_metric.get()
box_metric.get()

gluon.loss.Loss

class FocalLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs):
        super(FocalLoss, self).__init__(None, batch_axis, **kwargs)
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma

    def hybrid_forward(self, F, output, label):
        # Here `F` can be either mx.nd or mx.sym
        # 这里使用F取代在forward中显式的指定两者,方便使用
        # 所以非hybrid无此参数
        output = F.softmax(output)
        pj = output.pick(label, axis=self._axis, keepdims=True)
        loss = - self._alpha * ((1 - pj) ** self._gamma) * pj.log()
        return loss.mean(axis=self._batch_axis, exclude=True)

mxnet.contrib.ndarray.MultiBoxTarget

def training_targets(anchors, class_preds, labels):
    """
    得到的全部边框坐标
    得到的全部边框各个类别得分
    真实类别及对应边框坐标
    """
    class_preds = class_preds.transpose(axes=(0,2,1))
    return MultiBoxTarget(anchors, labels, class_preds)

# Output achors: (1, 5444, 4)
# Output class predictions: (1, 5444, 3)
# batch.label: (1, 1, 5)
out = training_targets(anchors, class_preds, batch.label[0][0:1]) 

mxnet.contrib.ndarray.MultiBoxDetection

 

『MXNet』第八弹_物体检测之SSD

标签:坐标   clu   net   war   loss   div   one   multi   class   

原文地址:https://www.cnblogs.com/hellcat/p/9108647.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!