码迷,mamicode.com
首页 > Web开发 > 详细

Transformer之encoder原理

时间:2019-12-15 10:57:12      阅读:306      评论:0      收藏:0      [点我收藏+]

标签:通过   res   学习   with   完成   参数   文章   tps   sqrt   

前言

前几天写了一篇关于BERT的博文,里面用到了Transformer的编码器,但是没有具体讲它的原理,所以在这篇文章里做一个补充。本文只阐述编码器encoder的部分,只做一个重点部分流程的概括,具体的最好还是看看原论文,然后关于解码器的部分之后有机会再讲。

encoder原理

我们主要根据下面一张图的流程来讲解

技术图片

1.首先假设我们有一组input:\(I = (I_1, I_2, ... I_n)\) 经过一个简单的embedding,其实就是做一个线性变换 \(\alpha = WI\)

2.然后加入每个token的position信息,其实就是直接把两个向量加起来 \(X = P + \alpha\)

3.下面就要进入一个循环体了(灰色框框内),也就是多层encode过程,我们提到过在BERT中有12层的base,也有24层的large。

3.1 第一步要经过一个multi-head attention,我们先从简单入手,介绍一个self-attention的单个头的情形。在这种最简单的情况下,把X经过三种不同的线性变换\(W^q, W^k, W^v\),得到三个结果分别代表query,key和value,表示为\(Q = W^qX, K = W^kX, V = W^vX\)。然后我们将每个Q元素与每个K元素相乘,相当于是一种双向的传播,表示为 \(Q^TK\),经过一层softmax后再与对应的\(V\)相乘得到最后的结果 \(O = V^Tsoftmax(Q^TK)\)。下面这个手绘的图展示了微观上的操作,可以参考一下:

技术图片

了解了单头的情况,多头也就容易了。multi-head只不过是把原来的输入X切成n段(嗯,真的就是直接切开就好),然后分别去做线性变换,变换完了以后再把他们拼接回来,又变成了一个同维度的整体。那么你可能会问,这样分开来做和一起做有什么区别吗?根据原论文的阐述,”Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.” 多头注意允许模型在不同的位置共同关注来自不同表示子空间的信息。只有一个注意力集中的头脑,平均化可以抑制这种情况。

3.2 接下来可以先做一个dropout的操作来防止过拟合,再把得到的输出O和之前的输入X做一个残差和ADD,再做一个layer norm。如果你不清楚什么是layer norm,大概是这个形式 \(\beta_1 * \frac{x - mean(x)}{\sqrt{var(x) + \epsilon}} + \beta_2\)

3.3 然后,我们把norm完的结果输入到全连接层中feed forward。这里的具体过程也很简单,我大致画了一下

技术图片

4.这样一次循环就结束了,一共需要经过N次循环,就完成了encoder编码任务。

补充点

  1. 这个encoder过程中有一些参数是人为规定好的,例如position embedding,有一些是需要通过学习得到的比如一些\(\beta\)

  2. 为了使用矩阵形式进行运算,对原始的句子是会做一些裁切和zero-padding的工作来确保长度一致。

  3. 关于attention中矩阵的shape,具体实践中不是简单的二维哦,里面还要加入batch维度,这个是由batch_size来决定的。假如batch size = N,token size = T,embedding size = E,那么输入的shape就是(N,T,E)。

Reference

1.Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008).

Transformer之encoder原理

标签:通过   res   学习   with   完成   参数   文章   tps   sqrt   

原文地址:https://www.cnblogs.com/mrdoghead/p/12041764.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!