码迷,mamicode.com
首页 > 其他好文 > 详细

损失函数Center Loss 代码解析

时间:2018-05-25 21:11:07      阅读:1976      评论:0      收藏:0      [点我收藏+]

标签:native   ref   blog   ima   als   dsa   sed   initial   sdn   

center loss来自ECCV2016的一篇论文:A Discriminative Feature Learning Approach for Deep Face Recognition。 
论文链接:http://ydwen.github.io/papers/WenECCV16.pdf 
代码链接:https://github.com/davidsandberg/facenet

理论解析请参看 https://blog.csdn.net/u014380165/article/details/76946339

下面给出centerloss的计算公式以及更新公式

技术分享图片

技术分享图片

 

技术分享图片

 

下面的代码是facenet作者利用tensorflow实现的centerloss代码

def center_loss(features, label, alfa, nrof_classes):
    """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
       (http://ydwen.github.io/papers/WenECCV16.pdf)
       https://blog.csdn.net/u014380165/article/details/76946339
    """
    nrof_features = features.get_shape()[1]
  #训练过程中,需要保存当前所有类中心的全连接预测特征centers, 每个batch的计算都要先读取已经保存的centers centers
= tf.get_variable(centers, [nrof_classes, nrof_features], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False) label = tf.reshape(label, [-1]) centers_batch = tf.gather(centers, label)#获取当前batch对应的类中心特征 diff = (1 - alfa) * (centers_batch - features)#计算当前的类中心与特征的差异,用于Cj的的梯度更新 centers = tf.scatter_sub(centers, label, diff)#更新梯度Cj,对于上图中步骤6,tensorflow会将该变量centers保留下来,用于计算下一个batch的centerloss loss = tf.reduce_mean(tf.square(features - centers_batch))#计算当前的centerloss 对应于Lc return loss, centers

 

损失函数Center Loss 代码解析

标签:native   ref   blog   ima   als   dsa   sed   initial   sdn   

原文地址:https://www.cnblogs.com/adong7639/p/9090421.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!