博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
论文笔记:Attention Is All You Need
阅读量:5940 次
发布时间:2019-06-19

本文共 3747 字,大约阅读时间需要 12 分钟。

Attention Is All You Need

2018-04-17 10:35:25 

 

1. Introduction: 

  现有的做 domain translation 的方法大部分都是基于 encoder-decoder framework,取得顶尖性能的框架也都是 RNN + Attention Mechanism 的思路。而本文别出心裁,仅仅依赖于 attention 机制,就可以做到很好的性能,并且,这种方法并适用于并行(parallelization)。

  

2. Model Architecture:  

  大部分神经序列转换模型(neural sequence transduction models)都有 encoder-decoder structure。此处,encoder 将输入的序列(x1, x2, ... , xn)转换为连续的表示 z = (z1, z2, ..., zn)。给定 z,decoder 然后每一输出一个元素,构成了序列 (y1, y2, ... , ym)。在每一个时间步骤,该模型是 auto-regressive,当产生下一个输出时,会使用上一个时刻产生的符号作为额外的输入(consuming the previously generated symbols as additional input when generating the next)。

 

2.1 Encoder and Decoder Stacks 

  

  Encoder:encoder 是由 6 个相同的 layer 堆叠起来的。每一个 layer 包括 两个 sub-layers:

  第一个是:mutli-head self-attention mechanism ; 

  第二个是:position-wise fully connected feed-forward network. 

  每一个这样的 two sub-layers 附近都会用上 residual connection,然后加上 layer normalization。

  

  Decoder:decoder 也有 6 层,不同的是:decoder layer 中包含 3个 sub-layers, which performs multi-head attention over the output of the encoder stack. 类似于 the encoder,我们采用 residual connections,followed by layer normalization. 我们也修改了 the self-attention sub-layer in the decoder。这个 masking,结合了这么一个事实:the output embeddings are offset by one position, 确保对于位置 i 的预测可以仅仅依赖于 the known outputs at positions less than i (less than i 是什么意思???). This maksing, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i 

 

2.2 Attention

  An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. 输出可以看做是 the values 的加权组合,给每一个 value 的加权可以计算为:a compatibility function of the query with the corresponding key. 

  

  2.2.1 Scaled Dot-Product Attention 

  我们称我们特定的 attention 为:“Scaled Dot-Product Attention”。输入包括:queries and keys of dimension $d_k$, and values of dimension $d_v$。我们计算 the dot products of the query with all keys, divide each by  and apply a softmax function to obtain the weights on the values. 

  实际上,我们同时在一个 queries 的集合上计算 attention function,将其打包为 a matrix Q。The keys and values 也被打包为:K and V. 我们计算输出的矩阵为:

  

  两个最常用的 attention functions 是:additive attention,and dot-product (multiplicative) attention.  

  

  2.2.2 Multi-Head Attention 

  用 $d_{model}-dimensional$ keys, values and queries,我们发现:it is beneficial to linearly project the queries, keys and values h times with different, learned linear projections to dk, dk and dv dimensions, respectively. 在每一个这些投影的版本,我们然后并行的执行 attention function,产生 dv-dimensional 的输出 values。这些东西组合起来,然后再次投影,得到最终的 values,如图2所示。

  Multi-head attention allows the model to jointly attend to information from different reprentation subspaces at different positions. With a single attention head, averaging inhibits this. 

  

  本文才用 h=8 并行的 attention layers,or heads. 

 

  2.2.3 Applications of Attention in our Model 

  

  2.3 Position-wise Feed-Forward Networks 

  除了 attention sub-layers, 我们 encoder and decoder 的每一层都包含一个全连接的 feed-forward network, 单独且平等的适用于每一个位置。这包含:two linear transformations with a ReLU activation in between: 

  

  2.4 Embeddings and Softmax 

  和其他序列转换模型一样,我们利用学习到的 embeddings 来转换输入的符号,然后输出符号为维度是  $d_{model}$ 的向量。

  

  2.5 Positional Encoding  

  由于我们的模型没有任何 recurrence 和 convolution,为了使得模型充分利用 sequence 的序列信息,我们必须注入相对或者绝对位置的信息。因此,我们将 “Positional encoding” to the input embeddings at the bottoms of the encoder and decoder stacks. 位置编码和 embeddings 有相同的维度,所以这两个东西可以相加起来。

  本文我们采用 sine and cosine functions of different frequencies: 

  

  其中,pos 是 位置,i 是维度。

 

  


 

 

  

=============================================================================

Reference: 

Paper:

Code(PyTorch Version):

另一个不错的关于这个文章的 Blog: 

 

 

 

 

转载地址:http://wrmtx.baihongyu.com/

你可能感兴趣的文章
安装win7提示安装程序无法创建新的系统分区和定位现有系统分区
查看>>
那些年,我跳过的坑(一)
查看>>
快递查询接口的调用与解析案例
查看>>
我的友情链接
查看>>
服务器性能优化配置建议
查看>>
GetWindowRect
查看>>
oracle sql语句实现累加、累减、累乘、累除
查看>>
SCNetworkReachabilityRef监测网络状态
查看>>
3D地图的定时高亮和点击事件(基于echarts)
查看>>
接口由40秒到200ms优化记录
查看>>
java 视频播放 多人及时弹幕技术 代码生成器 websocket springmvc mybatis SSM
查看>>
Activiti6.0,spring5,SSM,工作流引擎,OA
查看>>
第十三章:SpringCloud Config Client的配置
查看>>
使用 GPUImage 实现一个简单相机
查看>>
CoinWhiteBook:区块链在慈善事业中的应用
查看>>
【二】express
查看>>
Mac上基于Github搭建Hexo博客
查看>>
What does corn harvester involve?
查看>>
阿里云服务器ECS开放8080端口
查看>>
前端常用排序详解
查看>>