Motivation, references, and scope

Usage

Note: this post only analyzes the structure of v0.5.0. The main branch on GitHub plans a major directory restructuring for v0.6.0, which is roughly:

peft/tuners/lora.py 
->
peft/tuners/lora
  - __init__.py
  - config.py # LoraConfig
  - layer.py  # the LoRA Layer definitions
  - model.py  # LoraModel
  - bnb.py    # the bitsandbytes Layer definitions
  - gptq.py   # the GPTQ Layer definitions

There is no substantive code change, and because everything is re-exported through __init__.py the public interface is effectively unaffected, so this article sticks to the v0.5.0 layout.

Inheritance hierarchy of the Config classes:

PushToHubMixin: transformers.utils.PushToHubMixin, essentially just provides the `push_to_hub` method
  - PeftConfigMixin: peft.utils.config.PeftConfigMixin, essentially just adds the `save_pretrained` and `from_pretrained` methods
    - PeftConfig: peft.utils.config.PeftConfig, only adds attributes; the key ones are `peft_type`, `base_model_name_or_path`, `task_type`, `inference_mode`
      - LoraConfig: peft.tuners.lora.LoraConfig, sets `peft_type=PeftType.LORA` and adds further attributes
      - PromptLearningConfig: peft.utils.config.PromptLearningConfig, only adds attributes shared by the prefix_tuning, prompt_tuning and p_tuning algorithms
        - PrefixTuningConfig: peft.tuners.prefix_tuning.PrefixTuningConfig, sets `peft_type=PeftType.PREFIX_TUNING` and adds further attributes

In short, each peft/tuners/{peft_type}.py file defines the corresponding Config class, and this class is essentially a "dict" that additionally carries push_to_hub, save_pretrained and from_pretrained.
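As a quick illustration of that "dict with save/load" claim, here is a minimal round trip (a hedged sketch against peft v0.5.0; the directory name is made up):

from peft import LoraConfig

cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
cfg.save_pretrained("lora-cfg-demo")               # writes lora-cfg-demo/adapter_config.json
cfg2 = LoraConfig.from_pretrained("lora-cfg-demo")
cfg2.peft_type, cfg2.r                             # (<PeftType.LORA: 'LORA'>, 8)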

The important attributes of PeftConfig are:

  • task_type: sometimes used interchangeably with the term "model_type" (which in some contexts means the transformers architecture name, e.g. bert); this parameter must be specified for prefix_tuning, p_tuning and prompt_tuning
  • peft_type: a PeftType enum value identifying the algorithm (e.g. PeftType.LORA, PeftType.PREFIX_TUNING); each concrete config class sets it
  • is_prompt_learning: True only for prefix_tuning, p_tuning and prompt_tuning
  • is_adaption_prompt: True only for adaption_prompt

Summarized in the table below (covering only the peft types discussed in this post):

type             config class           peft_type                  task_type               is_prompt_learning  is_adaption_prompt
lora             LoraConfig             PeftType.LORA              optional                False               False
prefix_tuning    PrefixTuningConfig     PeftType.PREFIX_TUNING     required                True                False
prompt_tuning    PromptTuningConfig     PeftType.PROMPT_TUNING     required                True                False
p_tuning         PromptEncoderConfig    PeftType.P_TUNING          required                True                False
adaption_prompt  AdaptionPromptConfig   PeftType.ADAPTION_PROMPT   CAUSAL_LM in practice   False               True
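A few spot checks of these flags (a hedged sketch against peft v0.5.0; the PrefixTuningConfig values are made up):

from peft import LoraConfig, PrefixTuningConfig

LoraConfig().peft_type                   # <PeftType.LORA: 'LORA'>
LoraConfig().is_prompt_learning          # False
PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20).peft_type            # <PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>
PrefixTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=20).is_prompt_learning   # True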

PeftModel is the core user-facing class; internally it holds special modules such as LoraModel or PrefixEncoder, and PeftModelForSequenceClassification (and friends) are subclasses of PeftModel.

PeftModel is designed so that multiple adapters can be attached and switched:

lora_model = PeftModel(model, peft_config: PeftConfig, adapter_name="default")
lora_model.active_adapter  # default
lora_model.peft_config  # {"default": LoraConfig(...)}
lora_model.active_peft_config == lora_model.peft_config[lora_model.active_adapter]

lora_model.get_base_model()  # always the underlying transformers PreTrainedModel
lora_model.base_model        # depends on the algorithm: a LoraModel for lora, the PreTrainedModel itself for prefix_tuning
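How additional adapters get attached and switched, continuing the snippet above (a hedged sketch: add_adapter and set_adapter are PeftModel methods in v0.5.0, and the config values below are made up):

lora_model.add_adapter("other", LoraConfig(r=4, lora_alpha=8, target_modules=["query", "value"]))
lora_model.peft_config            # {"default": LoraConfig(...), "other": LoraConfig(...)}
lora_model.set_adapter("other")   # switches active_adapter here and on every wrapped layer
lora_model.active_adapter         # "other"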

Let's now unpack what happens here. In order of importance, we look at:

  • PeftModel.__init__
  • PeftModel.forward
  • saving the model
  • plugging LoRA adapters in and out

First, an end-to-end usage example (LoRA on a ViT image classifier, adapted from the peft docs):
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model


def print_trainable_parameters(model):
    # small helper (as in the peft docs example): report trainable vs. total parameter counts
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total:.2f}")


model_checkpoint = "google/vit-base-patch16-224-in21k"

model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    # label2id=label2id,
    # id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
print_trainable_parameters(model)

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)
lora_model = get_peft_model(model, config)  # in this case, essentially equivalent to the line below
# lora_model = PeftModel(model, config, adapter_name="default")
print_trainable_parameters(lora_model)

In the example above, the most important step inside PeftModel(model, config) is executing self.base_model = LoraModel(model, config).
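A quick structural check of the wrapped model, continuing the ViT example (a sketch; the exact module path depends on the transformers version, so treat it as an assumption):

type(lora_model)                    # peft.PeftModel
type(lora_model.base_model)         # peft.tuners.lora.LoraModel
type(lora_model.base_model.model)   # transformers ViTForImageClassification, i.e. the original model
type(lora_model.base_model.model.vit.encoder.layer[0].attention.attention.query)
# -> peft.tuners.lora.Linear: the targeted nn.Linear was replaced in place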

peft/tuners/lora.py is essentially just the definition of these classes:

LoraConfig(PeftConfig)
BaseTunerLayer(ABC)  # abstract merge/unmerge methods; peft.tuners.tuners_utils.BaseTunerLayer
  - LoraLayer(BaseTunerLayer)
    - Linear(nn.Linear, LoraLayer)
    - Embedding(nn.Embedding, LoraLayer)
    - Conv2d(nn.Conv2d, LoraLayer)
    - Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer)  # bitsandbytes-related
    - Linear4bit(bnb.nn.Linear4bit, LoraLayer)      # bitsandbytes-related
    - QuantLinear(torch.nn.Module, LoraLayer)       # seemingly GPTQ-related

BaseTuner(ABC, nn.Module)  # peft.tuners.tuners_utils.BaseTuner
LoraModel(BaseTuner)

First, a recap of the LoRA details:

  • Linear: for a base layer y = W x + b, LoRA adds a low-rank update y = W x + b + (lora_alpha / r) * B A x, where A has shape (r, in_features) and B has shape (out_features, r); merging simply folds (lora_alpha / r) * B A into W (a sketch is given after the Conv2d example below)
  • Conv2d: suppose the original layer is layer = Conv2d(in_feat, out_feat, conv_size, stride, ...); LoRA applies here because the kernel can be factored into two stacked convolutions:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    in_features = 3
    out_features = 4
    r = 2
    kernel_size = (3, 3)
    stride = 1
    padding = 1
    H, W = 10, 10
    bias = True

    x = torch.rand(2, in_features, H, W)  # dummy input used by both forward functions below
    
    layer_a = nn.Conv2d(in_features, r, kernel_size, stride, padding, bias=False)
    layer_b = nn.Conv2d(r, out_features, (1, 1), (1, 1), bias=False)
    layer = nn.Conv2d(in_features, out_features, kernel_size, stride, padding, bias=bias)
    
    def forward():
      result = F.conv2d(
          x,
          layer.weight,
          bias=layer.bias,
          stride=layer.stride,
          padding=layer.padding,
          dilation=layer.dilation,
      )
      result = result + layer_b(layer_a(x))
      return result
    
    def merged_forward():
        merged_w = F.conv2d(
            layer_a.weight.permute(1, 0, 2, 3),
            layer_b.weight,  # (out,r,1,1)
        ).permute(1, 0, 2, 3) + layer.weight
        result = F.conv2d(
            x,
            merged_w,
            bias=layer.bias,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
        )
        return result
    torch.allclose(forward(), merged_forward())  # True
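The Linear case mentioned above is even simpler; here is a minimal sketch of the same merge equivalence (mirroring what get_delta_weight computes in the source shown further down):

import torch
import torch.nn as nn
import torch.nn.functional as F

in_features, out_features, r, lora_alpha = 6, 4, 2, 8
scaling = lora_alpha / r

layer = nn.Linear(in_features, out_features)
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
nn.init.normal_(lora_B.weight)  # non-zero so the comparison below is meaningful

x = torch.rand(2, in_features)

# unmerged: base output plus the scaled low-rank update
unmerged = layer(x) + lora_B(lora_A(x)) * scaling
# merged: fold scaling * (B @ A) into the base weight, which is what Linear.merge() does
merged_w = layer.weight + (lora_B.weight @ lora_A.weight) * scaling
merged = F.linear(x, merged_w, layer.bias)

torch.allclose(unmerged, merged, atol=1e-6)  # True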
    

Next, let's walk through the source of the inheritance chain BaseTunerLayer -> LoraLayer -> Linear:

# BaseTunerLayer only defines one class attribute, active_adapter, plus two abstract methods, merge and unmerge
class LoraLayer(BaseTunerLayer):
    def __init__(self, in_features: int, out_features: int, **kwargs):
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}
        self.lora_dropout = nn.ModuleDict({})
        self.lora_A = nn.ModuleDict({})
        self.lora_B = nn.ModuleDict({})
        # For Embedding layer
        self.lora_embedding_A = nn.ParameterDict({})
        self.lora_embedding_B = nn.ParameterDict({})
        # Mark the weight as unmerged
        self.merged = False
        self.disable_adapters = False
        self.in_features = in_features
        self.out_features = out_features
        self.kwargs = kwargs

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Actual trainable parameters
        if r > 0:
            self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)}))
            self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)}))
            self.scaling[adapter_name] = lora_alpha / r   # note: this scaling factor is used in forward and get_delta_weight
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)
    
    def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): ...
    def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): ...

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B[adapter_name].weight)
        if adapter_name in self.lora_embedding_A.keys():
            # for embeddings the roles are swapped: A is zero-initialized and B gets a normal init
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])

class Linear(nn.Linear, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        adapter_name: str,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)

        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False

        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    def merge(self):
        if self.active_adapter not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            self.weight.data += self.get_delta_weight(self.active_adapter)
            self.merged = True

    def unmerge(self):
        if self.active_adapter not in self.lora_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            self.weight.data -= self.get_delta_weight(self.active_adapter)
            self.merged = False

    def get_delta_weight(self, adapter):
        return (
            transpose(
                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
                self.fan_in_fan_out,
            )
            * self.scaling[adapter]
        )

    def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            x = x.to(self.lora_A[self.active_adapter].weight.dtype)

            result += (
                self.lora_B[self.active_adapter](
                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
                )
                * self.scaling[self.active_adapter]
            )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = result.to(previous_dtype)

        return result

We only need to focus on the following usage:

from peft.tuners.lora import Linear
import torch

def show_fn(layer, x):
    print("main weight: ": layer.weight.data[:2, :2])
    print("adapters:", layer.lora_A.keys())
    print("active_adapter:", layer.active_adapter)
    print("active_adapter_delta_weight:", layer.get_delta_weight(layer.active_adapter))
    print("forward:", layer(x)[0])

# init_lora_weights defaults to True, meaning the added adapter's lora_B starts at zero (so the adapter initially has no effect); set it to False here so the deltas are visible
layer = Linear(adapter_name="adapter_a", in_features=6, out_features=4, r=3, lora_alpha=1, lora_dropout=0.0, init_lora_weights=False)
layer.eval()  # eval() (rather than setting layer.training directly) also puts the dropout submodules into eval mode
x = torch.rand(2, 6)
show_fn(layer, x)

layer.merge()                # fold the extra weights of "adapter_a" into layer.weight
show_fn(layer, x)

layer.update_layer(adapter_name="adapter_b", r=2, lora_alpha=0.6, lora_dropout=0.5, init_lora_weights=False)      # add parameters; always safe: it neither switches active_adapter nor merges anything
layer.eval()                 # the freshly created dropout submodule defaults to training mode, so switch back to eval
show_fn(layer, x)

# to be on the safe side, always unmerge before switching active_adapter:
layer.unmerge()              # must unmerge "adapter_a" first
layer.active_adapter = "adapter_b"
layer.merge()                # merge "adapter_b"
show_fn(layer, x)


layer.weight, layer.bias  # i.e. the weight and bias parameters of nn.Linear
layer.lora_A   # a {adapter_name: nn.Linear} dict
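For completeness, here is what parameters such a layer actually owns (a small self-contained check, assuming peft v0.5.0 is installed):

from peft.tuners.lora import Linear

layer = Linear(adapter_name="a", in_features=6, out_features=4, r=2)
list(layer.state_dict().keys())
# roughly: ['weight', 'bias', 'lora_A.a.weight', 'lora_B.a.weight']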

Next, let's walk through the source of the inheritance/containment chain BaseTuner, LoraModel, PeftModel:

class BaseTuner:
    # taking the subclass LoraModel as the example: the main job is to replace layers, e.g. swapping nn.Linear for peft.tuners.lora.Linear
    def inject_adapter(self, model, adapter_name): ...
class LoraModel:
    # enable/disable the adapter computation by flipping the disable_adapters switch on each peft.tuners.lora.Linear
    def enable_adapter_layers(self): ...
    def disable_adapter_layers(self): ...
    # set the active adapter
    def set_adapter(self, adapter_name): ...
    # delete an adapter
    def delete_adapter(self, adapter_name): ...
    
    @staticmethod
    def _replace_module(parent, child_name, new_module, child):
        setattr(parent, child_name, new_module)
        new_module.weight = child.weight  # this line is the key: the new module takes over the (possibly merged) weight
        ...
    # the main code, roughly: swap peft.tuners.lora.Linear back to an nn.Linear (optionally merging first)
    def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in tqdm(key_list, disable=not progressbar):
            parent, target, target_name = _get_submodules(self.model, key)
            if isinstance(target, peft.tuners.lora.Linear):
                new_module = torch.nn.Linear(target.in_features, target.out_features, bias=target.bias is not None)
                if merge:
                    target.merge()
                self._replace_module(parent, target_name, new_module, target)
            if isinstance(target, ModulesToSaveWrapper):
                setattr(parent, target_name, target.modules_to_save[target.active_adapter])
        return self.model
    def merge_and_unload(self, progressbar=False):
        return self._unload_and_optionally_merge(merge=True, progressbar=progressbar)
    def unload(self):
        return self._unload_and_optionally_merge(merge=False, progressbar=False)
    
    # the following methods can be skipped for now: they combine multiple adapters into one
    def add_weighted_adapter(...): ...
    def _svd_weighted_adapter(...): ...
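Typical use of the two unload methods above, continuing the ViT example from earlier (a sketch; lora_model is the PeftModel, so its base_model is the LoraModel):

merged = lora_model.base_model.merge_and_unload()   # a plain ViTForImageClassification with the deltas folded in
merged.save_pretrained("vit-lora-merged")           # hypothetical output directory; saved like any transformers model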

class PeftModel:
    def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"):
        super().__init__()
        self.base_model = model
        self.config = getattr(self.base_model, "config", {"model_type": "custom"})
        self.modules_to_save = None
        self.peft_config = {}
        self.active_adapter = adapter_name
        self.peft_type = peft_config.peft_type
        if not peft_config.is_prompt_learning:  # lora takes this branch: PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] is LoraModel
            self.peft_config[adapter_name] = peft_config
            self.base_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type](
                self.base_model, self.peft_config, adapter_name
            )
            self.set_additional_trainable_modules(peft_config, adapter_name)
        else:
            self.add_adapter(adapter_name, peft_config)
        # the rest deals with gradient_checkpointing and tensor parallelism; it can be ignored here
        if getattr(model, "is_gradient_checkpointing", True):
            model = self._prepare_model_for_gradient_checkpointing(model)
        if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
            self.base_model.config.pretraining_tp = 1
    
    def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
        self.peft_config[adapter_name] = peft_config
        try:
            if peft_config.is_prompt_learning:  # currently only prefix_tuning, p_tuning, prompt_tuning
                if hasattr(self.config, "to_dict"):
                    dict_config = self.config.to_dict()
                else:
                    dict_config = self.config

                peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
                self._setup_prompt_encoder(adapter_name)
            elif peft_config.is_adaption_prompt:  # currently only adaption_prompt
                self.base_model.add_adapter(adapter_name, peft_config)
            else:   # Lora, ...
                self.base_model.inject_adapter(self, adapter_name)
        except Exception:  # something went wrong, roll back
            del self.peft_config[adapter_name]
            raise

        self.set_additional_trainable_modules(peft_config, adapter_name)

    def set_additional_trainable_modules(self, peft_config, adapter_name):
        if getattr(peft_config, "modules_to_save", None) is not None:
            if self.modules_to_save is None:
                self.modules_to_save = set(peft_config.modules_to_save)
            else:
                self.modules_to_save.update(peft_config.modules_to_save)
            _set_trainable(self, adapter_name)  # details omitted

    def forward(self, *args: Any, **kwargs: Any):
        return self.get_base_model()(*args, **kwargs)

    # note: this may return a PeftModel or a subclass such as PeftModelForSequenceClassification
    @classmethod
    def from_pretrained(...): ...
  • How do we add another adapter_name to a LoraModel after it has been instantiated?
    # option 1 (untested, following BaseTuner's __init__ method)
    lora_model = LoraModel(...)
    new_adapter_name, new_peft_config = "new_adapter", LoraConfig(...)
    lora_model.peft_config.update({new_adapter_name: new_peft_config})
    lora_model.inject_adapter(lora_model.model, new_adapter_name)  # side effect: the existing layers get touched/replaced again
    lora_model.model.peft_config = lora_model.peft_config

    # option 2 (untested, following LoraModel's _set_adapter_layers method)
    new_adapter_name, new_peft_config = "new_adapter", LoraConfig(...)
    lora_model.peft_config.update({new_adapter_name: new_peft_config})
    for module in lora_model.model.modules():
        if isinstance(module, LoraLayer):
            module.update_layer(...)  # or update_layer_conv2d / update_layer_embedding
    lora_model.model.peft_config = lora_model.peft_config
    
  • Subclasses of PeftModel such as PeftModelForSequenceClassification mainly just override the forward method

Going back to the beginning: there are several ways to initialize a "peft model":

# way 1: PeftModel's __init__ method

# way 2: get_peft_model, which essentially just dispatches to the __init__ of PeftModel / PeftModelForSeq2SeqLM / ...
def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
    """
    Returns a Peft model object from a model and a config.

    Args:
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
    """
    model_config = getattr(model, "config", {"model_type": "custom"})
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
        return PeftModel(model, peft_config, adapter_name=adapter_name)
    if peft_config.is_prompt_learning:
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)

# way 3: PeftModel.from_pretrained
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

# way 4: AutoPeftModel.from_pretrained, essentially AutoModelForXXX.from_pretrained + PeftModelForYYY.from_pretrained
from peft import AutoPeftModelForCausalLM, AutoPeftModel
peft_model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
peft_model = AutoPeftModel.from_pretrained("smangrul/openai-whisper-large-v2-LORA-colab")
  • PeftModel.__init__, PeftModelForSeq2SeqLM.__init__: the other ways all build on these
  • get_peft_model: dispatches to the __init__ of PeftModel / PeftModelForSeq2SeqLM
  • PeftModel.from_pretrained: __init__ + load_state_dict
  • AutoPeftModel.from_pretrained: AutoModelForXXX.from_pretrained + PeftModelForYYY.from_pretrained
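One item from the earlier checklist that has not been shown yet is saving. A hedged sketch of the adapter save/load round trip, continuing the ViT example (file names as I recall them for v0.5.0):

# saving stores only the adapter weights plus the config, not the base model
lora_model.save_pretrained("vit-lora")   # writes adapter_config.json and adapter_model.bin

# reloading: rebuild the base model first, then attach the adapter (this is "way 3" above)
from transformers import AutoModelForImageClassification
from peft import PeftModel

base = AutoModelForImageClassification.from_pretrained(model_checkpoint, ignore_mismatched_sizes=True)
lora_model = PeftModel.from_pretrained(base, "vit-lora")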
# lora
PeftModel: contains a LoraModel; the LoraModel wraps a PreTrainedModel in which the targeted nn.Linear layers have been replaced with peft.tuners.lora.Linear
# p_tuning: must go through a PeftModelForXXX class
PeftModelForCausalLM: wraps a PreTrainedModel directly; PeftModelForCausalLM's forward modifies the inputs on top of the PreTrainedModel's forward

Summary of the various peft methods:

For encoder-only / decoder-only architectures (PeftModelForCausalLM etc.):

  • prompt_tuning / p_tuning: prepend the prefix directly to the input embeddings; the rest of the computation is unchanged (see the sketch below)
  • prefix_tuning: the input embeddings stay as-is; instead, every layer gets an extra set of keys and values
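A minimal sketch (not PEFT's actual code) of what the prompt_tuning / p_tuning case does inside PeftModelForCausalLM.forward: prepend the learned virtual-token embeddings to the input embeddings and extend the attention mask accordingly:

import torch

batch_size, seq_len, hidden = 2, 5, 16
num_virtual_tokens = 4

inputs_embeds = torch.rand(batch_size, seq_len, hidden)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
prompts = torch.rand(batch_size, num_virtual_tokens, hidden)   # would come from PromptEmbedding / PromptEncoder

inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)     # (B, num_virtual_tokens + L, H)
attention_mask = torch.cat(
    [torch.ones(batch_size, num_virtual_tokens, dtype=torch.long), attention_mask], dim=1
)
# the wrapped PreTrainedModel is then called with inputs_embeds / attention_mask as usual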

For encoder-decoder architectures [to be studied]:

prompt_tuning

# config:
# {"prompt_tuning_init_text": "hello", "prompt_tuning_init": "TEXT"}
# {"prompt_tuning_init": "RANDOM"}
class PromptEmbedding(torch.nn.Module):
    def __init__(self, config, word_embeddings: nn.Embedding):
        # word_embeddings is the base model's (frozen) token embedding layer
        super().__init__()
        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT:  # TEXT or RANDOM
            # first convert config.prompt_tuning_init_text into token ids / embeddings via word_embeddings, then fit them to total_virtual_tokens:
            # if shorter than total_virtual_tokens, repeat: [11, 2, 3], 5 -> [11, 2, 3, 11, 2]
            # if longer than total_virtual_tokens, truncate: [11, 2, 3, 4, 5], 3 -> [11, 2, 3]
            # finally use the resulting embeddings to initialize self.embedding.weight
            ...
        
    def forward(self, indices):
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings
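The tile/truncate rule described in the comments above, as a tiny self-contained illustration (fit_to_length is a hypothetical helper, not part of PEFT):

import math

def fit_to_length(token_ids, total_virtual_tokens):
    # repeat the init-text token ids until they cover total_virtual_tokens, then cut off the excess
    if len(token_ids) >= total_virtual_tokens:
        return token_ids[:total_virtual_tokens]
    num_reps = math.ceil(total_virtual_tokens / len(token_ids))
    return (token_ids * num_reps)[:total_virtual_tokens]

fit_to_length([11, 2, 3], 5)        # [11, 2, 3, 11, 2]
fit_to_length([11, 2, 3, 4, 5], 3)  # [11, 2, 3]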

p_tuning

At first glance this is just a reparameterization of prompt_tuning:

class PromptEncoder(torch.nn.Module):
    def __init__(self, config):
        ...
        self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim)
        if not config.inference_mode:
            if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
                self.lstm_head = torch.nn.LSTM(...)
                self.mlp_head = torch.nn.Sequential(...)  # Linear, Relu, Linear
            elif self.encoder_type == PromptEncoderReparameterizationType.MLP:  # the default
                self.mlp_head = torch.nn.Sequential(...)  # Linear, Relu, Linear, Relu, Linear
            else:
                raise ValueError("...")
    def forward(self, indices):  # what about inference_mode=True? forward does not handle it; as far as I can tell PeftModel.get_prompt then uses self.embedding.weight directly and never calls this forward
        input_embeds = self.embedding(indices)
        if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
            output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0])
        elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
            output_embeds = self.mlp_head(input_embeds)
        else:
            raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
        return output_embeds

prefix_tuning

class PrefixEncoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        token_dim = config.token_dim
        num_layers = config.num_layers
        encoder_hidden_size = config.encoder_hidden_size
        num_virtual_tokens = config.num_virtual_tokens
        if self.prefix_projection and not config.inference_mode:  # in inference_mode only the embedding is kept
            self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
            self.transform = torch.nn.Sequential(
                torch.nn.Linear(token_dim, encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
            )
        else:
            self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)
    
    def forward(self, prefix: torch.Tensor):  # every layer gets its own key and value, i.e. the added virtual tokens bypass the normal embedding/attention input path (see the sketch after this class)
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.transform(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
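How the PrefixEncoder output turns into per-layer key/value tensors, as a hedged, self-contained sketch of what PeftModel.get_prompt roughly does for prefix_tuning (the shapes are illustrative):

import torch

batch_size, num_virtual_tokens = 2, 8
num_layers, num_heads, head_dim = 12, 12, 64
token_dim = num_heads * head_dim

# output of PrefixEncoder.forward: one vector of size num_layers * 2 * token_dim per virtual token
past_key_values = torch.rand(batch_size, num_virtual_tokens, num_layers * 2 * token_dim)

past_key_values = past_key_values.view(
    batch_size, num_virtual_tokens, num_layers * 2, num_heads, head_dim
)
past_key_values = past_key_values.permute(2, 0, 3, 1, 4).split(2)   # one (key, value) pair per layer

len(past_key_values)        # 12, i.e. num_layers
past_key_values[0].shape    # torch.Size([2, 2, 12, 8, 64]) = (key/value, batch, heads, virtual tokens, head_dim)
# each pair is handed to the base model as that layer's past_key_values, so the virtual tokens
# only ever contribute keys and values and never enter through the normal embedding path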