Motivation, References, and Scope

Taking Mixtral 8x7B as an example, this note records its inference and training flow.

Inference: MixtralSparseMoeBlock

The "only" difference between Mixtral 8x7B and a vanilla Transformer is that the Feed-Forward layer is replaced by a sparse mixture-of-experts (MoE) block. The huggingface source code implements this with a lot of torch tensor indexing tricks, which makes it somewhat hard to follow, so a simplified version is given below.
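
First, note that each expert inside the MoE block is itself just an ordinary SwiGLU-style MLP (MixtralBlockSparseTop2MLP in the huggingface code), so the MoE block only changes which MLPs a token is routed through, not what each MLP computes. Below is a minimal standalone sketch of one expert; the class name SingleExpertMLP is made up for illustration, while the w1/w3/w2 + SiLU structure follows the huggingface implementation.

import torch
import torch.nn as nn

class SingleExpertMLP(nn.Module):
    # hypothetical standalone version of one expert (cf. MixtralBlockSparseTop2MLP)
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # up projection
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)  # down projection
        self.act_fn = nn.SiLU()

    def forward(self, x):  # x: (num_tokens, hidden_dim)
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))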

from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
import torch
import torch.nn.functional as F

class MyMixtralSparseMoeBlock(torch.nn.Module):
    def __init__(self, hf_module: MixtralSparseMoeBlock):
        super().__init__()
        self.hf_module = hf_module

    def forward(self, hidden_states):
        # self.hf_module.gate = torch.nn.Linear(hidden_dim, num_experts)
        # self.hf_module.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(num_experts)])  # feed_forward
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        top_k, num_experts = self.hf_module.top_k, self.hf_module.num_experts

        router_logits = self.hf_module.gate(hidden_states)  # (B, L, num_experts)
        routing_weights = F.softmax(router_logits, dim=-1)
        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
        # routing_weights, selected_experts: (B, L, top_k)
        # e.g.:
        # routing_weights[0][4] = [0.3, 0.2], selected_experts[0][4] = [0, 7]
        # means that for sequence 0 (out of B=2 sequences in this example),
        # token 4 is routed to experts 0 and 7

        # renormalize the top-k weights so they sum to 1
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # routing_weights[0][4] = [0.6, 0.4]

        # token-by-token loop: each token's output is the weighted sum of its top-k experts
        final_hidden_states = torch.zeros(batch_size, sequence_length, hidden_dim)
        for b in range(batch_size):
            for l in range(sequence_length):
                for k in range(top_k):
                    expert = self.hf_module.experts[int(selected_experts[b, l, k])]
                    h = expert(hidden_states[b, l].view(1, -1))
                    final_hidden_states[b, l] += h.view(-1) * routing_weights[b, l, k]
        return final_hidden_states, router_logits.view(batch_size * sequence_length, num_experts)

B, L, C = 2, 64, 128
# 8 experts in total, only 2 activated per token: note that within a single sequence,
# the first token may activate experts [0, 7] while the second may activate [2, 4]
num_expert, top_k = 8, 2
config = MixtralConfig(hidden_size=C, num_local_experts=num_expert, num_experts_per_tok=top_k)

hf_module = MixtralSparseMoeBlock(config)
my_module = MyMixtralSparseMoeBlock(hf_module)
hidden_states = torch.rand(B, L, C)

a, b = hf_module(hidden_states)  # (B, L, C), (B*L, num_expert)
c, d = my_module(hidden_states)  # (B, L, C), (B*L, num_expert)
print(torch.allclose(a.detach(), c.detach(), atol=1e-6))
print(torch.allclose(b.detach(), d.detach(), atol=1e-6))
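
For reference, the tensor-indexing tricks in the huggingface implementation essentially invert the loop order of the simplified version above: instead of looping over tokens and calling one expert per (token, k) pair, it loops over experts and processes all tokens routed to a given expert in a single batched call. The sketch below (reusing the imports and hf_module from the snippet above; the function name expert_major_forward is mine, not the library's) shows that expert-major formulation, not the exact huggingface code.

def expert_major_forward(moe: MixtralSparseMoeBlock, hidden_states):
    # same math as MyMixtralSparseMoeBlock, but grouped by expert instead of by token
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    flat = hidden_states.view(-1, hidden_dim)                     # (B*L, C)
    router_logits = moe.gate(flat)                                # (B*L, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, moe.top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # (B*L, top_k)

    final = torch.zeros_like(flat)
    for e in range(moe.num_experts):
        token_idx, k_idx = torch.where(selected_experts == e)     # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        expert_out = moe.experts[e](flat[token_idx])              # one batched call per expert
        final.index_add_(0, token_idx, expert_out * routing_weights[token_idx, k_idx, None])
    return final.view(batch_size, sequence_length, hidden_dim), router_logits

e_out, e_logits = expert_major_forward(hf_module, hidden_states)
print(torch.allclose(a.detach(), e_out.detach(), atol=1e-6))  # should also be True (up to summation order)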