动机、参考资料、涉及内容

动机

  • 学习 RWKV 的运作方式
  • 比较 RNN 与 Transformer
  • 分析 Transformer 各个组件的理论计算复杂度与实际运行时间

RNN 与 Transformer 推理/训练性能理论分析【Transformer实测分析待补充,参考李沐视频讲解】

本节主要讨论 RNN 和 Transformer (本节的特指使用标准 scaled dot-product attention 的原版 Transformer) 在训练时与推理时的性能

RNN

我们只考虑单向的 RNN, 相比于 Transformer 的三种模型结构 (encoder-only, decoder-only, encoder-decoder), RNN 对应的结构大概是这样【配图】:

  • encoder-only: 一个 RNN 作为 encoder
  • decoder-only: 一个 RNN 作为 decoder
  • encoder-decoder: 一个 RNN 作为 encoder, 另一个 RNN 作为 decoder

从性能分析的角度看, RNN 的 decoder-only 与 encoder-decoder 没有本质区别.

我们先简要回顾一下原始 RNN、LSTM、GRU 的计算逻辑, 【待补】

由此可见, RNN 在训练和推理时对序列长度均不能做并行处理, 但是它在推理与训练时的计算复杂度与长度 $L$ 是线性关系.

Transformer

标准的 scaled dot-product attention, 对于一个头, 在训练时, 计算公式如下: $Q\in\mathbb{R}^{L\times d}, K\in\mathbb{R}^{L\times d}, V\in\mathbb{R}^{L\times d}$,

\[Y = \text{softmax}(\frac{QK^T}{\sqrt{d} })V\]

其中

\[[\text{softmax}(A)]_{i, j}=\frac{\exp(A_{i, j})}{\sum_{j=1}^{L}\exp(A_{i, j})}\]

为了使分析更为仔细, 我们需要区分三种模型结构 (encoder-only, decoder-only, encoder-decoder) 在训练和推理时的计算复杂度.

首先回顾一下三种模型结构:

  • encoder-only (Bert): 这种结构涉及到的 attention 层只有 encoder-self-attention. 在训练和推理时都可以对长度并行处理
  • decoder-only (GPT2): 这种结构涉及到的 attention 层只有 decoder-self-attention. 在训练时可以利用 mask 对长度并行处理, 而推理时只能采用自回归的方式处理
  • encoder-decoder (T5): 这种结构涉及到的 attention 层有 encoder-self-attention, decoder-self-attention, decoder-cross-attention. 其中 encoder-self-attention 在训练和推理都可以对长度并行处理, decoder-self-attention 在训练时可以通过 mask 对长度进行并行处理, 推理时只能采用自回归的方式处理, decoder-cross-attention 在训练时可以对长度并行处理, 而推理也只能采用自回归的方式进行处理.

为了记号简单, 我们后面有时将 encoder-only, decoder-only, encoder-decoder 分别记作: EnO, DeO, EnDe, 将 encoder-self-attention, decoder-self-attention, decoder-cross-attention 分别记作: EnSA, DeSA, DeCA. 注意不同模型架构相同名称的 attention 层, 处理的数据量是不同的, 例如: EnO-EnSA 与 EnDe-EnSA, 前者既要处理输入又要处理输出, 后者只需要处理输出.

其次我们对所谓的“训练”与“推理”的含义做一些限定:

  • encoder-only (Bert)
    • 训练: 特指使用 encoder-only 做文本分类问题的训练过程
    • 推理: 特指使用 encoder-only 做文本分类问题的推理过程
  • decoder-only (GPT2)
    • 训练: 特指使用 decoder-only 做文本续写任务的训练过程
    • 推理: 特指使用 decoder-only 做文本续写任务的推理过程
  • encoder-decoder (T5)
    • 训练: 特指使用 encoder-decoder 做文本续写任务的训练过程
    • 推理: 特指使用 encoder-decoder 做文本续写任务的推理过程

我们分析一下浮点操作次数, $d$ 代表每个头的特征维数, $h$ 代表头的个数, $D=d\times h$ 代表输入/输出隐层维数, $L$ 代表输入长度, $M$ 代表输出长度 ($M$ 的概念仅适用于 decoder-only 与 encoder-decoder 模型结构中), 注意在下面的分析中, 我们将指数, 加法, 除法操作的复杂度均视为相同的常数. 提示: 形状为 $(m, k)$ 与形状为 $(k, n)$ 的矩阵乘法的浮点操作次数为 $m\times k\times n$.

  • qkv 的计算
    • encoder-self-attention (encoder-only/encoder-decoder): 训练/推理: $O(6L\times D^2)$.
      • 训练: 可并行处理
      • 推理: 可并行处理
    • decoder-self-attention (decoder-only): 训练/推理: $O(6(L+M)\times D^2)$.
      • 训练: 对 $L$ 和 $M$ 均可并行处理
      • 推理: 对 $L$ 可并行处理, 对 $M$ 不可并行处理
    • decoder-self-attention (encoder-decoder): 训练/推理: $O(6M\times D^2)$.
      • 训练: 可并行处理
      • 推理: 不可并行处理
    • decoder-cross-attention: 训练/推理: $O(6M\times D^2)$.
      • 训练: 可并行处理
      • 推理: 不可并行处理
  • o 的计算: 与 qkv 类似, 常数系数由 6 改为 2.
  • encoder-self-attention (encoder-only/encoder-decoder):
    • 训练: $O(h\times(2\times L^2d+L\times(3\times L)+2\times L^2d))=O(4L^2D+3hL^2)$, 其中第 2 项为 softmax 的浮点数计算次数, 可并行处理. 我们可以看到, 复杂度与 $L$ 呈平方关系.
    • 推理: 与训练时一样, 也可以并行处理
  • decoder-self-attention (decoder-only):
    • 训练: $\sum_{l=1}^{L+M}{h\times(2ld+3l+2ld)}=O(2(L+M)^2D+\frac{3}{2}h(L+M)^2)$, 可完全并行处理
    • 推理: 计算复杂度与训练时一致, 对求和式的前面 $L$ 可并行处理, 对求和式后 $M$ 个不能并行, 因此输出序列越长, 并行度越低
  • decoder-self-attention (encoder-decoder)
    • 训练: $\sum_{l=1}^{M}{h\times(2ld+3l+2ld)}=O(2M^2D+\frac{3}{2}hM^2)$, 可并行处理
    • 推理: 计算复杂度与训练时一致, 不可并行处理
  • decoder-cross-attention (encoder-decoder):
    • 训练: $O(M\times h\times(2Ld+3L+2Ld))=O(4LMD+3LMh)$, 可并行处理
    • 推理: 计算复杂度与训练时一致, 不可并行处理

总结如下:

encoder-only

算子类型 模型结构 运行时 计算复杂度(含常数项且含低阶项) 并行性质
encoder-self-attention encoder-only 训练/推理 $O(4L^2D+3hL^2)$ 对 $L$ 可并行处理
qkv计算 encoder-only 训练/推理 $O(6L\times D^2)$ 对 $L$ 可并行处理
o计算 encoder-only 训练/推理 $O(2L\times D^2)$ 对 $L$ 可并行处理
合计(encoder-self-attention) encoder-only 训练/推理 $O(8LD^2+4L^2D+3hL^2)$ 对 $L$ 可并行处理

decoder-only

算子类型 模型结构 运行时 计算复杂度(含常数项且含低阶项) 并行性质
decoder-self-attention decoder-only 训练 $O(2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ 对 $L$ 和 $M$ 均可并行处理
qkv计算 decoder-only 训练 $O(6(L+M)\times D^2)$ 对 $L$ 和 $M$ 均可并行处理
o计算 decoder-only 训练 $O(2(L+M)\times D^2)$ 对 $L$ 和 $M$ 均可并行处理
合计(decoder-self-attention) decoder-only 训练 $O(8(L+M)D^2+2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ 对 $L$ 和 $M$ 均可并行处理
decoder-self-attention decoder-only 推理 $O(2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ $O(2L^2D+\frac{3}{2}hL^2)$ 的计算量可并行处理
qkv计算 decoder-only 推理 $O(6(L+M)\times D^2)$ $O(6LD^2)$ 的计算量可并行处理
o计算 decoder-only 推理 $O(2(L+M)\times D^2)$ $O(2LD^2)$ 的计算量可并行处理
合计(decoder-self-attention) decoder-only 推理 $O(8(L+M)D^2+2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ $O(8LD^2+2L^2D+\frac{3}{2}hL^2)$ 的计算量可并行处理

encoder-decoder

EnDe-EnSA 模块

算子类型 模型结构 运行时 计算复杂度(含常数项且含低阶项) 并行性质
encoder-self-attention encoder-decoder 训练/推理 $O(4L^2D+3hL^2)$ 对 $L$ 可并行处理
qkv计算(EnSA) encoder-decoder 训练/推理 $O(6L\times D^2)$ 对 $L$ 可并行处理
o计算(EnSA) encoder-decoder 训练/推理 $O(2L\times D^2)$ 对 $L$ 可并行处理
合计(encoder-self-attention) encoder-decoder 训练/推理 $O(8LD^2+4L^2D+3hL^2)$ 对 $L$ 可并行处理

EnDe-DeSA 模块

算子类型 模型结构 运行时 计算复杂度(含常数项且含低阶项) 并行性质
decoder-self-attention encoder-decoder 训练 $O(2M^2D+\frac{3}{2}hM^2)$ 对 $M$ 可并行处理
qkv计算(DeSA) encoder-decoder 训练 $O(6MD^2)$ 对 $M$ 可并行处理
o计算(DeSA) encoder-decoder 训练 $O(2MD^2)$ 对 $M$ 可并行处理
合计(decoder-self-attention) encoder-decoder 训练 $O(8MD^2+2M^2D+\frac{3}{2}hM^2)$ 对 $M$ 可并行处理
decoder-self-attention encoder-decoder 推理 $O(2M^2D+\frac{3}{2}hM^2)$ 对 $M$ 不可并行处理
qkv计算(DeSA) encoder-decoder 推理 $O(6MD^2)$ 对 $M$ 不可并行处理
o计算(DeSA) encoder-decoder 推理 $O(2MD^2)$ 对 $M$ 不可并行处理
合计(decoder-self-attention) encoder-decoder 推理 $O(8MD^2+2M^2D+\frac{3}{2}hM^2)$ 对 $M$ 不可并行处理

EnDe-DeCA 模块

算子类型 模型结构 运行时 计算复杂度(含常数项且含低阶项) 并行性质
decoder-cross-attention encoder-decoder 训练 $O(4LMD+3LMh)$ 可并行处理
qkv计算(DeCA) encoder-decoder 训练 $O(6MD^2)$ 对 $M$ 可并行处理
o计算(DeCA) encoder-decoder 训练 $O(2MD^2)$ 对 $M$ 可并行处理
合计(decoder-cross-attention) encoder-decoder 训练 $O(8MD^2+4LMD+3LMh)$ 可并行处理
decoder-cross-attention encoder-decoder 推理 $O(4LMD+3LMh)$ 对 $M$ 不可并行处理
qkv计算(DeCA) encoder-decoder 推理 $O(6MD^2)$ 对 $M$ 不可并行处理
o计算(DeCA) encoder-decoder 推理 $O(2MD^2)$ 对 $M$ 不可并行处理
合计(decoder-cross-attention) encoder-decoder 推理 $O(8MD^2+4LMD+3LMh)$ 对 $M$ 不可并行处理

AFT

参考资料

RWKV 参考了 2021.05 发表的一篇 paper An Attention Free Transformer, 简称 AFT, 这篇论文主要是对常见的 Multihead scaled dot-product Attention 进行改进 (解决计算复杂度是序列长度的平方).

一般意义上的 attention 可以用这种方式进行定义: $V\in\mathbb{R}^{L\times d}$ 代表需要被注意的值, $\mathbf{y}\in\mathbb{R}^{d}$

\[\mathbf{y}=\sum_{l=1}^{L}{a_l V_{l, :}}\]

写成分量形式:

\[y_i = \langle \mathbf{a}, V_{:, i}\rangle\]

AFT 的计算公式如下: $Q\in\mathbb{R}^{L\times D}, K\in\mathbb{R}^{L\times D}, V\in\mathbb{R}^{L\times D}, W\in\mathbb{R}^{L\times L}$. 其中 $Q, K, V$ 由输入 $X\in\mathbb{R}^{L\times D}$ 经过可学习的全连接层计算得到, 而 $W$ 为可学习参数.

\[Y_{l,:}=\text{sigmoid}(Q_{l, :})\odot \frac{\sum_{l'=1}^{L}{[\exp{(K_{l',:} + W_{l, l'})}\odot V_{l',:}] }}{\sum_{l'=1}^{L}{\exp{(K_{l',:} + W_{l, l'})} } }\]

其中 $Y\in\mathbb{R}^{L\times D}$ 为 AFT 的计算结果, 注意上式中的指数项内部实际上是一个 $D$ 维向量, 除法也是逐元素的除法. 分析可知, 采用这种定义, 计算得出 $Y$ 的计算复杂度为: $O(6L^2D+2LD)$. 这里的复杂度可以用如下代码解释:

# Q, K, V: (L, D), W: (L, L)

Y = torch.zeros((L, D))
for l in range(L):
    Q_l_gate = torch.sigmoid(Q[l, :])  # D
    score = K + W[l].view(-1, 1)       # LD
    weight = torch.exp(score)          # LD
    sum = weight.sum(dim=-1)            # LD, sum: (D,)
    Y_l = torch.zeros(D)
    for i in range(L):
        Y_l += weight[i, :] / sum[i] * V[i, :]  # 3D
    Y_l = Y_l * Q_1_gate  # D
    Y[l] = Y_l