动机、参考资料、涉及内容

待定

总览

硬件相关

  • Habana: HPU 的支持
  • Intel: Intel Neural Compressor/OpenVINO 的支持
  • AWS Trainium/Inferentia: AWS 支持
  • Furiosa: NPU 的支持

模型格式转换

  • Onnx Runtime
  • Exporters: onnx/tflite

其他:

  • Torch FX: torch.fx, 没找到应用
  • BetterTransformer: 推理优化(目前都还是基于pytorch原生的优化)
    • fastpath: nn.TransformerEncoderLayer
    • Flash Attention and Memory-Efficient Attention: torch.nn.functional.scaled_dot_product_attention
  • LLM Quantization: 大模型量化的支持
    • GPTQ

备注: 似乎不包含 pytorch 原生的 int8 量化的支持

BetterTransformer

flash-attention 作为 bettertransformer 的一部分被实现在 pytorch 中: torch.nn.functional.scaled_dot_product_attention

bettertransformer:

🤗 optimum 中用这样的做法将模型进行转换

from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")
bt_model = BetterTransformer.transform(model, keep_original_model=True)  # 与正常的推理过程唯一增加的一行, 其余地方都不动
# input_ids, attention_mask= ...
model.generate(input_ids, attention_mask=masks, max_new_tokens=20, pad_token_id=model.config.eos_token_id)

实现原理如下

# optimum.bettertransformer.models.decoder_models.GPT2AttentionLayerBetterTransformer
class BetterTransformer:
    def transform(model, ...):
        # 递归函数, 最终执行的是
        # bettertransformer_module = BetterTransformerManager.MODEL_MAPPING[config.model_type][target_class](module, config)
        # model._modules[name] = bettertransformer_module

        # 而 BetterTransformerManager.MODEL_MAPPING 包含:
        # "gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer}
    def reverse(...):
        ...

from transformer.models.gpt2 import GPT2Attention

# optimum.bettertransformer.models.base.BetterTransformerBaseLayer
class BetterTransformerBaseLayer:
    ...  # 没有太多涉及到forward的东西

# optimum.bettertransformer.models.decoder_models.GPT2AttentionLayerBetterTransformer
class GPT2AttentionLayerBetterTransformer(BetterTransformerBaseLayer, GPT2Attention):
    _attn = gpt2_wrapped_scaled_dot_product  # 此处为关键, 将 GPT2Attention 中的 _attn 函数替换
    def forward(self, *args, **kwargs):
        super().forward_checker()
        return super().forward(*args, **kwargs)

def gpt2_wrapped_scaled_dot_product(...):  # 这个函数用来替换 GPT2Attention._attn 函数
    # 最终使用到了 pytorch 的 torch.nn.functional.scaled_dot_product_attention 函数