码迷,mamicode.com
首页 > 其他好文 > 详细

【CVPR2020】Gated Channel Transformation for Visual Recognition

时间:2020-07-29 00:39:40      阅读:110      评论:0      收藏:0      [点我收藏+]

标签:问题:   合作   工作   apt   pytorch   ali   inf   lob   img   

代码:https://github.com/z-x-yang/GCT

这是一个百度和悉尼科技大学合作的工作,作者指出,SENet modulate feature maps on the channel-wise level. 但是,SE 模块使用了个全连接层(FC层)处理 channel-wise embeddings,这样会产生两个问题:

  • 在CNN中应用SE模块时,模块数量有限制。SE主要在模块上应用(Res-Block或者Inception-Block),FC层无法在网络的所有层上应用。
  • 由于FC层参数复杂,难以分析网络不同层间通道的关联性。

为了解决上述两个问题,作者提出了 Gated Channel Transformation (GCT)模块,主要有如下设计:

  • 使用一个 归一化模块 替换FC,对通道间的特征关系建模
  • 设计了一系列参数,应用 gating mechanism(门控机制) 对通道间的特征关系建模

通过归一化模块和门控机制,GCT模块可以捕获通道特征间的 “竞争” 和 “合作”。如下图所示,GCT模块可以促进 shallow layer 特征间的合作,同时,促进 deep layer 特征间的竞争。这样,浅层特征可以更好的获取通用的属性,深层特征可以更好的获取与任务相关的 discriminative 特征。

技术图片

GCT模块包括三个部分:global context embedding, channel normalization, gating adaptation,如下图所示:

技术图片

第一部分:Global context embedding

与SE模块不同,GCT模块没有采用全局池化 (GAP) 的方式,因为在某些情况下GAP会失效。比如在某些应用中会使用 instance normalization,会固定各个通道的均值,这样得到的结果向量就会变成常量。因此,作者使用L2 norm 进行global context embeding:

技术图片

这里引入了\(\alpha\),是一组可训练的参数。

第二部分:Channel Normalization

在这部分仍然使用L2 norm,如下所示:

技术图片

\(\sqrt{C}\)用来避免当\(C\)比较大时,\(\hat{s}_c\)的值过小。与 SE 的 FC层相比,该通道归一化方法计算量更小。

第三部分:Gating adaptation

作者提出了门控机制来 adapte the original feature。使用门控机制,GCT可以在训练过程中促进特征间的竞争与合作:

技术图片

作者指出,这里设计了权重\(\gamma\) 和 偏置 \(\beta\) 来控制通道特征是否激活。当一个通道的特征权重 \(\gamma_c\)被正激活,GCT将促进这个通道的特征和其它通道的特征“竞争”。当一个通道的特征 \(\gamma_c\) 被负激活,GCT将促进这个通道的特征和其它通道的特征“合作”。

作者在论文中还提供了GCT模块的 pytorch 实现代码:

def forward(self, x, epsilon=1e-5):
  # x: input features with shape [N,C,H,W] 
  # alpha, gamma, beta: embedding weight, gating weight,
  # gating bias with shape [1,C,1,1]
  embedding = (x.pow(2).sum((2,3), keepdim=True)
    + epsilon).pow(0.5) * self.alpha
  norm = self.gamma / (embedding.pow(2).mean(dim=1,        	
    keepdim=True) + epsilon).pow(0.5)
  gate = 1. + torch.tanh(embedding * norm + self.beta)
  return x * gate

论文有趣的地方是作者做了一个实验,来分析在RestNet50中,\(\gamma\)的变化。可以看出,作者得出一个结论,在网络的浅层,\(\gamma\)的值比较小,普遍在0以下,说明特征间中合作关系;在网络的深层,\(\gamma\)的值就在增大,增长到0以上,说明特征间是竞争关系,有助于分类。

技术图片

其它部分的内容不过多介绍了,具体可以参考作者论文。

【CVPR2020】Gated Channel Transformation for Visual Recognition

标签:问题:   合作   工作   apt   pytorch   ali   inf   lob   img   

原文地址:https://www.cnblogs.com/gaopursuit/p/13394627.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!