知乎专栏:10分钟带你深入理解Transformer原理及实现
关于encoder和decoder
Encoder总览
Encoder端由N(原论文中N=6)个相同的大模块堆叠而成,其中每个大模块又由两个子模块构成,这两个子模块分别为多头self-attention模块,以及一个前馈神经网络模块。
需要注意的是,Encoder端每个大模块接收的输入是不一样的,第一个大模块(最底下的那个)接收的输入是输入序列的embedding(embedding可以通过word2vec预训练得来),其余大模块接收的是其前一个大模块的输出,最后一个模块的输出作为整个Encoder端的输出。
Decoder总览
Decoder端同样由N(原论文中N=6)个相同的大模块堆叠而成,其中每个大模块则由三个子模块构成,这三个子模块分别为多头self-attention模块,多头Encoder-Decoder attention交互模块,以及一个前馈神经网络模块。
同样需要注意的是,Decoder端每个大模块接收的输入也是不一样的,其中第一个大模块(最底下的那个)训练时和测试时的接收的 输入是不一样的,并且每次训练时接收的输入也可能是不一样的(也就是模型总览图示中的shifted right),其余大模块接收的是同样是其前一个大模块的输出,最后一个模块的输出作为整个Decoder端的输出。
对于第一个大模块,简而言之,其训练及测试时接收的输入为:
训练的时候每次的输入为上次的输入加上输入序列向后移一位的ground truth(例如每向后移一位就是一个新的单词,那么则加上其对应的embedding),特别地,当decoder的time step为1时(也就是第一次接收输入),其输入为一个特殊的token,可能是目标序列开始的token(如),也可能是源序列结尾的token(如),也可能是其它视任务而定的输入等等,不同源码中可能有微小的差异,其目标则是预测下一个位置的单词(token)是什么,对应到time step为1时,则是预测目标序列的第一个单词(token)是什么,以此类推;
这里需要注意的是,在实际实现中可能不会这样每次动态的输入,而是一次性把目标序列的embedding通通输入第一个大模块中,然后在多头attention模块对序列进行mask即可
在测试的时候,是先生成第一个位置的输出,然后有了这个之后,第二次预测时,再将其加入输入序列,以此类推直至预测结束。