头图

背景:Transformer诞生于NLP,自大模型热潮以来,Transformer也被大量研究者应用于视觉领域。笔者主要从事视觉领域的算法研发,为了解Transformer的结构、优势等,阅读和调研了一些以往鲜少了解的RNN、时序处理、语音处理等领域的资料,结合自己的思考谈谈对Transformer结构的理解,如有错误,欢迎留言/私信指正。

Transformer原文:https://arxiv.org/abs/1706.03762#

关于Transformer的结构,各博客文章、视频教程等已有详细的介绍,本文不再赘述。

笔者也是看了许多资料,发现大部分资料对于网络结构中的运算已经说明得非常详尽,但是很少有资料会说明为什么Transformer要设计这样的运算,以及这样的运算到底带来了什么好处

现在笔者通过对比RNN、CNN,来谈一下笔者本人的见解。

RNN中的迭代运算

先来看一下RNN的结构:

image.png

这个图中,X是输入,O是输出,S是中间节点的运算结果,U、V、W都是权值参数。这个图是说上一个运算出来的S会作为入参参与到当前的运算中,写成公式:

$$ O_t = g(V \cdot S_t) $$

$$ S_t = f(U \cdot X_t + W \cdot S_{t-1}) $$

可以看出,RNN之所以可以用来处理时序问题,又可以支持不定长的输入,就在于它的迭代运算,即用上一次的运算结果作为当前的入参。可以从下面这个动图中得到更直观的理解:

575e2-2019-07-02-input-5.gif
(图片来源:https://easyai.tech/ai-definition/rnn/

可以看出,这种运算需要将输入一个一个传给网络,很显然这就带来了运算效率上的问题

(PS:此外,RNN在训练时还存在梯度消失的问题,不能支持过长的输入,虽然LSTM通过添加“遗忘门”结构尝试解决历史信息的“遗忘”,但仍没有从根本上解决该问题。)

CNN中的卷积运算

上文提到了RNN运算存在效率的问题。而卷积运算因为一次可以运算“一批”数据,在效率上会更高一些。下图为一个二维数据做卷积运算的示意图:

Convolution_schematic.gif
(图片来源:http://ufldl.stanford.edu/tutorial/supervised/FeatureExtracti...

由于卷积运算的高效,卷积结构便被大量应用于视觉领域,在Transformer之前,几乎所有视觉任务的神经网络都是CNN结构。

而利用卷积运算处理时序任务,WaveNet应当是一个里程碑式的工作。

WaveNet原文:https://arxiv.org/abs/1609.03499

WaveNet的贡献在于提出膨胀因果卷积(dilated causal convolution) 的运算方法,如下图:

v2-c30c38e855aa24273739940a3ff411a1_b.gif
(图片来源:https://zhuanlan.zhihu.com/p/161667958

所谓因果卷积,是指输出数据只由在它序号之前的输入数据计算得到;而所谓膨胀卷积,是指跳过一些中间序号的输入数据计算输出数据。因为有了因果膨胀卷积,这使得网络可以支持很长的输入。

可以看到,WaveNet虽将卷积运算引入到时序分析中,一次处理“一批”数据,但效率仍然很低,因为它仍然需要将上一次处理的结果作为当前运算的输入,这本质上也是在做迭代,也是顺次将数据放进网络。

Transformer中的矩阵乘法运算

Transformer中最重要的运算当属所谓“自注意力机制”了:

image.png

现在来详细解读一下这个运算过程。

上文公式中的\(Q\)、\(K\)、\(V\)是通过将词嵌入矩阵\(X\)分别与三个权值矩阵\(W_Q\)、\(W_K\)、\(W_V\)相乘得到的,如下图展示了\(X\)与\(W_Q\)相乘的运算过程:

image.png

笔者认为,这个运算是整个网络结构中最需要重点讲述的部分。

原论文中,词嵌入是用512维的向量表示的,所以输入的\(X\)矩阵是512列,这里输入的词汇个数可以是不定长的,我们假设输入词汇数为N,所以\(X\)矩阵是N行。

原论文中仅说明了权值矩阵一共8列,但由于权值矩阵要与输入矩阵做矩阵乘法,所以权值矩阵一定是512行的,正如上面的示意图,权值矩阵\(W_Q\)是一个固定维度的512行8列的矩阵,\(W_K\)、\(W_V\)也是相同的维度。

将矩阵\(X\)和矩阵\(W_Q\)相乘,根据矩阵乘法的运算法则,那么得到的\(Q\)矩阵的行数一定是\(X\)矩阵的行数,列数一定是\(W_Q\)矩阵的列数,所以\(Q\)矩阵是N行8列的,很显然,这个N就是输入的词汇数。

\(K\)与\(V\)的运算与\(Q\)相同,所以\(Q\)、\(K\)、\(V\)三个矩阵都是N行8列的。

接下来就是自注意力机制的运算(也就是上文公式中的运算),将\(Q\)乘以\(K\)的转置,除以一个系数(为了防止结果过大)取softmax再乘以\(V\)。

我们可以看到,因为引入了矩阵乘法,这使得我们的网络输入可以是任意长度(输入句子的词汇可以是任意个数),这是CNN做不到的,而同时又是因为矩阵乘法,我们又可以对一个任意长度的输入做并行运算(同时处理输入句子中的所有词汇,而不是一个词一个词处理),而这是RNN做不到的。无论输入长度是怎样的,经过上述矩阵乘法再经过自注意力机制的运算,一定会得到一个维数固定的矩阵,从而进行后面的运算。

总结一下:Transformer用矩阵乘法,替换RNN中的迭代运算或CNN中的卷积运算,从而使得网络可以对不定长的输入(CNN做不到或者效果不好)做并行运算(RNN做不到)。 这便是Transformer结构的优势。

(PS:可以看到,上述的运算词汇的位置信息并没有被包含进去,为了将词汇位置信息也加入训练,原文对位置信息做了编码,原文中,一个词汇编码成了512维,作者便也将位置编码成512维,然后把这个位置向量与词汇向量直接相加。)

Transformer应用于视觉

Transformer在时序处理的领域(语音、NLP)取得了巨大的成功,同时视觉领域的研究者也在探索如何将其应用在图像、视频这种类型的数据里。

这里介绍一个里程碑工作:Vision Transformer(ViT),是将Transformer结构应用于图像的比较成功的工作。

原文链接:https://arxiv.org/abs/2010.11929

这里,是把一张图切成9个patch,每个patch用线性投影层转成向量,再在每个向量前拼接上1、2...9这些数字,用来表示位置信息,数字0+分类词汇向量作为标签,把这个向量作为transformer的输入,从而替代了卷积运算,如下图:

image.png

可以看出(这里是仅为笔者的观点),这样的应用似乎不太符合直觉,或者,有些“生硬”,这样做丢弃了图像中各像素间的位置关系这样的信息,像是为了用Transformer而用Transformer。

笔者认为,卷积运算因为其参数共享和平移不变的特点,实际上很天然地适合处理非时序的二维数据,对于处理图像任务,也许不应当把卷积运算完全丢弃了,也许将Transformer和CNN两者优势结合起来是更好的结构。


ZincO
1 声望1 粉丝