(P1) RWKV 浅析
动机、参考资料、涉及内容
动机
- 学习 RWKV 的运作方式
- 比较 RNN 与 Transformer
- 分析 Transformer 各个组件的理论计算复杂度与实际运行时间
RNN 与 Transformer 推理/训练性能理论分析【Transformer实测分析待补充,参考李沐视频讲解】
本节主要讨论 RNN 和 Transformer (本节的特指使用标准 scaled dot-product attention 的原版 Transformer) 在训练时与推理时的性能
RNN
我们只考虑单向的 RNN, 相比于 Transformer 的三种模型结构 (encoder-only, decoder-only, encoder-decoder), RNN 对应的结构大概是这样【配图】:
- encoder-only: 一个 RNN 作为 encoder
- decoder-only: 一个 RNN 作为 decoder
- encoder-decoder: 一个 RNN 作为 encoder, 另一个 RNN 作为 decoder
从性能分析的角度看, RNN 的 decoder-only 与 encoder-decoder 没有本质区别.
我们先简要回顾一下原始 RNN、LSTM、GRU 的计算逻辑, 【待补】
由此可见, RNN 在训练和推理时对序列长度均不能做并行处理, 但是它在推理与训练时的计算复杂度与长度 $L$ 是线性关系.
Transformer
标准的 scaled dot-product attention, 对于一个头, 在训练时, 计算公式如下: $Q\in\mathbb{R}^{L\times d}, K\in\mathbb{R}^{L\times d}, V\in\mathbb{R}^{L\times d}$,
\[Y = \text{softmax}(\frac{QK^T}{\sqrt{d} })V\]其中
\[[\text{softmax}(A)]_{i, j}=\frac{\exp(A_{i, j})}{\sum_{j=1}^{L}\exp(A_{i, j})}\]为了使分析更为仔细, 我们需要区分三种模型结构 (encoder-only, decoder-only, encoder-decoder) 在训练和推理时的计算复杂度.
首先回顾一下三种模型结构:
- encoder-only (Bert): 这种结构涉及到的 attention 层只有 encoder-self-attention. 在训练和推理时都可以对长度并行处理
- decoder-only (GPT2): 这种结构涉及到的 attention 层只有 decoder-self-attention. 在训练时可以利用 mask 对长度并行处理, 而推理时只能采用自回归的方式处理
- encoder-decoder (T5): 这种结构涉及到的 attention 层有 encoder-self-attention, decoder-self-attention, decoder-cross-attention. 其中 encoder-self-attention 在训练和推理都可以对长度并行处理, decoder-self-attention 在训练时可以通过 mask 对长度进行并行处理, 推理时只能采用自回归的方式处理, decoder-cross-attention 在训练时可以对长度并行处理, 而推理也只能采用自回归的方式进行处理.
为了记号简单, 我们后面有时将 encoder-only, decoder-only, encoder-decoder 分别记作: EnO, DeO, EnDe, 将 encoder-self-attention, decoder-self-attention, decoder-cross-attention 分别记作: EnSA, DeSA, DeCA. 注意不同模型架构相同名称的 attention 层, 处理的数据量是不同的, 例如: EnO-EnSA 与 EnDe-EnSA, 前者既要处理输入又要处理输出, 后者只需要处理输出.
其次我们对所谓的“训练”与“推理”的含义做一些限定:
- encoder-only (Bert)
- 训练: 特指使用 encoder-only 做文本分类问题的训练过程
- 推理: 特指使用 encoder-only 做文本分类问题的推理过程
- decoder-only (GPT2)
- 训练: 特指使用 decoder-only 做文本续写任务的训练过程
- 推理: 特指使用 decoder-only 做文本续写任务的推理过程
- encoder-decoder (T5)
- 训练: 特指使用 encoder-decoder 做文本续写任务的训练过程
- 推理: 特指使用 encoder-decoder 做文本续写任务的推理过程
我们分析一下浮点操作次数, $d$ 代表每个头的特征维数, $h$ 代表头的个数, $D=d\times h$ 代表输入/输出隐层维数, $L$ 代表输入长度, $M$ 代表输出长度 ($M$ 的概念仅适用于 decoder-only 与 encoder-decoder 模型结构中), 注意在下面的分析中, 我们将指数, 加法, 除法操作的复杂度均视为相同的常数. 提示: 形状为 $(m, k)$ 与形状为 $(k, n)$ 的矩阵乘法的浮点操作次数为 $m\times k\times n$.
- qkv 的计算
- encoder-self-attention (encoder-only/encoder-decoder): 训练/推理: $O(6L\times D^2)$.
- 训练: 可并行处理
- 推理: 可并行处理
- decoder-self-attention (decoder-only): 训练/推理: $O(6(L+M)\times D^2)$.
- 训练: 对 $L$ 和 $M$ 均可并行处理
- 推理: 对 $L$ 可并行处理, 对 $M$ 不可并行处理
- decoder-self-attention (encoder-decoder): 训练/推理: $O(6M\times D^2)$.
- 训练: 可并行处理
- 推理: 不可并行处理
- decoder-cross-attention: 训练/推理: $O(6M\times D^2)$.
- 训练: 可并行处理
- 推理: 不可并行处理
- encoder-self-attention (encoder-only/encoder-decoder): 训练/推理: $O(6L\times D^2)$.
- o 的计算: 与 qkv 类似, 常数系数由 6 改为 2.
- encoder-self-attention (encoder-only/encoder-decoder):
- 训练: $O(h\times(2\times L^2d+L\times(3\times L)+2\times L^2d))=O(4L^2D+3hL^2)$, 其中第 2 项为 softmax 的浮点数计算次数, 可并行处理. 我们可以看到, 复杂度与 $L$ 呈平方关系.
- 推理: 与训练时一样, 也可以并行处理
- decoder-self-attention (decoder-only):
- 训练: $\sum_{l=1}^{L+M}{h\times(2ld+3l+2ld)}=O(2(L+M)^2D+\frac{3}{2}h(L+M)^2)$, 可完全并行处理
- 推理: 计算复杂度与训练时一致, 对求和式的前面 $L$ 可并行处理, 对求和式后 $M$ 个不能并行, 因此输出序列越长, 并行度越低
- decoder-self-attention (encoder-decoder)
- 训练: $\sum_{l=1}^{M}{h\times(2ld+3l+2ld)}=O(2M^2D+\frac{3}{2}hM^2)$, 可并行处理
- 推理: 计算复杂度与训练时一致, 不可并行处理
- decoder-cross-attention (encoder-decoder):
- 训练: $O(M\times h\times(2Ld+3L+2Ld))=O(4LMD+3LMh)$, 可并行处理
- 推理: 计算复杂度与训练时一致, 不可并行处理
总结如下:
encoder-only
算子类型 | 模型结构 | 运行时 | 计算复杂度(含常数项且含低阶项) | 并行性质 |
---|---|---|---|---|
encoder-self-attention | encoder-only | 训练/推理 | $O(4L^2D+3hL^2)$ | 对 $L$ 可并行处理 |
qkv计算 | encoder-only | 训练/推理 | $O(6L\times D^2)$ | 对 $L$ 可并行处理 |
o计算 | encoder-only | 训练/推理 | $O(2L\times D^2)$ | 对 $L$ 可并行处理 |
合计(encoder-self-attention) | encoder-only | 训练/推理 | $O(8LD^2+4L^2D+3hL^2)$ | 对 $L$ 可并行处理 |
decoder-only
算子类型 | 模型结构 | 运行时 | 计算复杂度(含常数项且含低阶项) | 并行性质 |
---|---|---|---|---|
decoder-self-attention | decoder-only | 训练 | $O(2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ | 对 $L$ 和 $M$ 均可并行处理 |
qkv计算 | decoder-only | 训练 | $O(6(L+M)\times D^2)$ | 对 $L$ 和 $M$ 均可并行处理 |
o计算 | decoder-only | 训练 | $O(2(L+M)\times D^2)$ | 对 $L$ 和 $M$ 均可并行处理 |
合计(decoder-self-attention) | decoder-only | 训练 | $O(8(L+M)D^2+2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ | 对 $L$ 和 $M$ 均可并行处理 |
decoder-self-attention | decoder-only | 推理 | $O(2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ | $O(2L^2D+\frac{3}{2}hL^2)$ 的计算量可并行处理 |
qkv计算 | decoder-only | 推理 | $O(6(L+M)\times D^2)$ | $O(6LD^2)$ 的计算量可并行处理 |
o计算 | decoder-only | 推理 | $O(2(L+M)\times D^2)$ | $O(2LD^2)$ 的计算量可并行处理 |
合计(decoder-self-attention) | decoder-only | 推理 | $O(8(L+M)D^2+2(L+M)^2D+\frac{3}{2}h(L+M)^2)$ | $O(8LD^2+2L^2D+\frac{3}{2}hL^2)$ 的计算量可并行处理 |
encoder-decoder
EnDe-EnSA 模块
算子类型 | 模型结构 | 运行时 | 计算复杂度(含常数项且含低阶项) | 并行性质 |
---|---|---|---|---|
encoder-self-attention | encoder-decoder | 训练/推理 | $O(4L^2D+3hL^2)$ | 对 $L$ 可并行处理 |
qkv计算(EnSA) | encoder-decoder | 训练/推理 | $O(6L\times D^2)$ | 对 $L$ 可并行处理 |
o计算(EnSA) | encoder-decoder | 训练/推理 | $O(2L\times D^2)$ | 对 $L$ 可并行处理 |
合计(encoder-self-attention) | encoder-decoder | 训练/推理 | $O(8LD^2+4L^2D+3hL^2)$ | 对 $L$ 可并行处理 |
EnDe-DeSA 模块
算子类型 | 模型结构 | 运行时 | 计算复杂度(含常数项且含低阶项) | 并行性质 |
---|---|---|---|---|
decoder-self-attention | encoder-decoder | 训练 | $O(2M^2D+\frac{3}{2}hM^2)$ | 对 $M$ 可并行处理 |
qkv计算(DeSA) | encoder-decoder | 训练 | $O(6MD^2)$ | 对 $M$ 可并行处理 |
o计算(DeSA) | encoder-decoder | 训练 | $O(2MD^2)$ | 对 $M$ 可并行处理 |
合计(decoder-self-attention) | encoder-decoder | 训练 | $O(8MD^2+2M^2D+\frac{3}{2}hM^2)$ | 对 $M$ 可并行处理 |
decoder-self-attention | encoder-decoder | 推理 | $O(2M^2D+\frac{3}{2}hM^2)$ | 对 $M$ 不可并行处理 |
qkv计算(DeSA) | encoder-decoder | 推理 | $O(6MD^2)$ | 对 $M$ 不可并行处理 |
o计算(DeSA) | encoder-decoder | 推理 | $O(2MD^2)$ | 对 $M$ 不可并行处理 |
合计(decoder-self-attention) | encoder-decoder | 推理 | $O(8MD^2+2M^2D+\frac{3}{2}hM^2)$ | 对 $M$ 不可并行处理 |
EnDe-DeCA 模块
算子类型 | 模型结构 | 运行时 | 计算复杂度(含常数项且含低阶项) | 并行性质 |
---|---|---|---|---|
decoder-cross-attention | encoder-decoder | 训练 | $O(4LMD+3LMh)$ | 可并行处理 |
qkv计算(DeCA) | encoder-decoder | 训练 | $O(6MD^2)$ | 对 $M$ 可并行处理 |
o计算(DeCA) | encoder-decoder | 训练 | $O(2MD^2)$ | 对 $M$ 可并行处理 |
合计(decoder-cross-attention) | encoder-decoder | 训练 | $O(8MD^2+4LMD+3LMh)$ | 可并行处理 |
decoder-cross-attention | encoder-decoder | 推理 | $O(4LMD+3LMh)$ | 对 $M$ 不可并行处理 |
qkv计算(DeCA) | encoder-decoder | 推理 | $O(6MD^2)$ | 对 $M$ 不可并行处理 |
o计算(DeCA) | encoder-decoder | 推理 | $O(2MD^2)$ | 对 $M$ 不可并行处理 |
合计(decoder-cross-attention) | encoder-decoder | 推理 | $O(8MD^2+4LMD+3LMh)$ | 对 $M$ 不可并行处理 |
AFT
参考资料
- 原始论文: An Attention Free Transformer
- 实现, 作者们似乎没有将代码开源, 这里找到几个看上去靠谱的开源代码
- labml: https://nn.labml.ai/transformers/aft/index.html, 也可以直接看 代码
- 一个 star 数较多的第三方实现: https://github.com/rish-16/aft-pytorch
RWKV 参考了 2021.05 发表的一篇 paper An Attention Free Transformer, 简称 AFT, 这篇论文主要是对常见的 Multihead scaled dot-product Attention 进行改进 (解决计算复杂度是序列长度的平方).
一般意义上的 attention 可以用这种方式进行定义: $V\in\mathbb{R}^{L\times d}$ 代表需要被注意的值, $\mathbf{y}\in\mathbb{R}^{d}$
\[\mathbf{y}=\sum_{l=1}^{L}{a_l V_{l, :}}\]写成分量形式:
\[y_i = \langle \mathbf{a}, V_{:, i}\rangle\]AFT 的计算公式如下: $Q\in\mathbb{R}^{L\times D}, K\in\mathbb{R}^{L\times D}, V\in\mathbb{R}^{L\times D}, W\in\mathbb{R}^{L\times L}$. 其中 $Q, K, V$ 由输入 $X\in\mathbb{R}^{L\times D}$ 经过可学习的全连接层计算得到, 而 $W$ 为可学习参数.
\[Y_{l,:}=\text{sigmoid}(Q_{l, :})\odot \frac{\sum_{l'=1}^{L}{[\exp{(K_{l',:} + W_{l, l'})}\odot V_{l',:}] }}{\sum_{l'=1}^{L}{\exp{(K_{l',:} + W_{l, l'})} } }\]其中 $Y\in\mathbb{R}^{L\times D}$ 为 AFT 的计算结果, 注意上式中的指数项内部实际上是一个 $D$ 维向量, 除法也是逐元素的除法. 分析可知, 采用这种定义, 计算得出 $Y$ 的计算复杂度为: $O(6L^2D+2LD)$. 这里的复杂度可以用如下代码解释:
# Q, K, V: (L, D), W: (L, L)
Y = torch.zeros((L, D))
for l in range(L):
Q_l_gate = torch.sigmoid(Q[l, :]) # D
score = K + W[l].view(-1, 1) # LD
weight = torch.exp(score) # LD
sum = weight.sum(dim=-1) # LD, sum: (D,)
Y_l = torch.zeros(D)
for i in range(L):
Y_l += weight[i, :] / sum[i] * V[i, :] # 3D
Y_l = Y_l * Q_1_gate # D
Y[l] = Y_l