(P1) MoE: (Mixtral 8x7B)
Motivation, references, and scope
Taking Mixtral 8x7B as the example, this note records its inference and training flow.
Inference: MixtralSparseMoeBlock
The "only" difference between Mixtral 8x7B and a vanilla transformer is that the Feed-Forward layer is replaced by a sparse MoE block. The huggingface source code implements this with a lot of torch tensor indexing operations, which makes it hard to follow, so a simplified version is given below (a sketch of the expert-major formulation that the huggingface code actually uses follows the verification script at the end).
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
import torch
import torch.nn.functional as F
class MyMixtralSparseMoeBlock(torch.nn.Module):
def __init__(self, hf_module: MixtralSparseMoeBlock):
super().__init__()
self.hf_module = hf_module
def forward(self, hidden_states):
        # self.hf_module.gate: nn.Linear(hidden_dim, num_experts, bias=False), the router
        # self.hf_module.experts: nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(num_experts)]), the expert feed-forward networks
batch_size, sequence_length, hidden_dim = hidden_states.shape
router_logits = self.hf_module.gate(hidden_states) # (B, L, num_expert)
routing_weights = F.softmax(router_logits, dim=-1)
        routing_weights, selected_experts = torch.topk(routing_weights, self.hf_module.top_k, dim=-1)
        # routing_weights, selected_experts: (B, L, top_k)
        # For example:
        #   routing_weights[0][4] = [0.3, 0.2], selected_experts[0][4] = [0, 7]
        # means that for sequence 0 (out of B=2 sequences in this example), token 4 routes to experts 0 and 7
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # after renormalization: routing_weights[0][4] = [0.6, 0.4]
        final_hidden_states = torch.zeros_like(hidden_states)
        # Token-major loops: for every token, run only its top_k selected experts and
        # accumulate their outputs weighted by the renormalized routing weights
        for b in range(batch_size):
            for l in range(sequence_length):
                for k in range(self.hf_module.top_k):
                    expert = self.hf_module.experts[int(selected_experts[b, l, k])]
                    h = expert(hidden_states[b, l].view(1, -1))
                    final_hidden_states[b, l] += h.view(-1) * routing_weights[b, l, k]
        return final_hidden_states, router_logits.view(batch_size * sequence_length, self.hf_module.num_experts)
B, L, C = 2, 64, 128
# 8 experts in total, only 2 of them are activated per token: note that within one sequence, the first token may activate experts [0, 7] while the second token may activate [2, 4]
num_expert, top_k = 8, 2
config = MixtralConfig(hidden_size=C, num_local_experts=num_expert, num_experts_per_tok=top_k)
hf_module = MixtralSparseMoeBlock(config)
my_module = MyMixtralSparseMoeBlock(hf_module)
hidden_states = torch.rand(B, L, C)
a, b = hf_module(hidden_states) # (B, L, C), (B*L, num_expert)
c, d = my_module(hidden_states) # (B, L, C), (B*L, num_expert)
print(torch.allclose(a.detach(), c.detach(), atol=1e-6))
print(torch.allclose(b.detach(), d.detach(), atol=1e-6))
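For reference, the reason the huggingface implementation looks index-heavy is that it iterates over experts rather than over tokens, so each expert processes all of its routed tokens in a single batched call. The following is a minimal sketch of that expert-major formulation, not the exact huggingface code; expert_major_forward is a hypothetical helper name, and the sketch assumes the block exposes num_experts and top_k attributes (as the transformers MixtralSparseMoeBlock does). It should produce the same outputs as the token-major loops above.

def expert_major_forward(moe: MixtralSparseMoeBlock, hidden_states: torch.Tensor):
    # Sketch only: loop over experts instead of tokens, batching all tokens per expert
    B, L, C = hidden_states.shape
    x = hidden_states.reshape(B * L, C)                              # flatten tokens
    router_logits = moe.gate(x)                                      # (B*L, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, moe.top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)     # (B*L, top_k)
    out = torch.zeros_like(x)
    for expert_idx in range(moe.num_experts):
        # token_idx: which tokens routed to this expert; slot: at which top_k position
        token_idx, slot = torch.where(selected_experts == expert_idx)
        if token_idx.numel() == 0:
            continue
        expert_out = moe.experts[expert_idx](x[token_idx])           # one batched call per expert
        out.index_add_(0, token_idx, expert_out * routing_weights[token_idx, slot].unsqueeze(-1))
    return out.reshape(B, L, C), router_logits

e_out, e_logits = expert_major_forward(hf_module, hidden_states)
print(torch.allclose(a.detach(), e_out.detach(), atol=1e-6))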
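As for the training flow: both forwards also return the router logits because Mixtral adds a load-balancing auxiliary loss to the language-modeling loss, encouraging tokens to be spread evenly across the experts. Below is a simplified, Switch-Transformer-style sketch of that idea; simple_load_balancing_loss is a hypothetical helper, not the exact huggingface implementation (which handles additional details such as aggregating logits across layers).

def simple_load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int):
    # router_logits: (num_tokens, num_experts); simplified sketch of the auxiliary loss
    probs = F.softmax(router_logits, dim=-1)                          # (num_tokens, num_experts)
    _, selected_experts = torch.topk(probs, top_k, dim=-1)            # (num_tokens, top_k)
    expert_mask = F.one_hot(selected_experts, num_experts).float()    # (num_tokens, top_k, num_experts)
    # fraction of (token, slot) assignments dispatched to each expert
    tokens_per_expert = expert_mask.mean(dim=(0, 1))                  # (num_experts,)
    # average router probability assigned to each expert
    router_prob_per_expert = probs.mean(dim=0)                        # (num_experts,)
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

aux_loss = simple_load_balancing_loss(b, num_expert, top_k)  # b: router logits from the forward above
print(aux_loss)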