(P1) Huggingface PEFT
Motivation, references, and scope
Usage
Note: only the structure of v0.5.0 is analyzed here. The main branch on GitHub plans a major directory restructuring for v0.6.0, which is roughly:
peft/tuners/lora.py
->
peft/tuners/lora
- __init__.py
- config.py # LoraConfig
- layer.py # the lora Layer definitions
- model.py # LoraModel
- bnb.py # the bitsandbytes-related Layer definitions
- gptq.py # the GPTQ-related Layer definitions
There are no substantive code changes, and since the public interface is still re-exported through __init__.py, the restructuring is transparent to downstream code; this article therefore sticks to v0.5.0.
Inheritance hierarchy of the Config classes:
PushToHubMixin: transformers.utils.PushToHubMixin, essentially provides just the `push_to_hub` method
- PeftConfigMixin: peft.utils.config.PeftConfigMixin, essentially just adds the `save_pretrained` and `from_pretrained` methods
  - PeftConfig: peft.utils.config.PeftConfig, only adds some attributes; the key ones are `peft_type`, `base_model_name_or_path`, `task_type`, `inference_mode`
    - LoraConfig: peft.tuners.lora.LoraConfig, sets `peft_type=PeftType.LORA` and adds a few lora-specific attributes
    - PromptLearningConfig: peft.utils.config.PromptLearningConfig, only adds attributes shared by the prefix_tuning, prompt_tuning, and p_tuning (prompt encoder) algorithms
      - PrefixTuningConfig: peft.tuners.prefix_tuning.PrefixTuningConfig, sets `peft_type=PeftType.PREFIX_TUNING` and adds a few more attributes
In short, peft/tuners/{peft_type}.py defines the corresponding Config class, and that class is essentially a "dict" equipped with push_to_hub, save_pretrained, and from_pretrained.
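A minimal sketch of this "dict with save/load" view (the save directory below is made up):

from dataclasses import asdict
from peft import LoraConfig, PeftConfig

config = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
print(config.peft_type)     # PeftType.LORA
print(asdict(config)["r"])  # 8: the config really is just a dataclass, i.e. a dict with extras

config.save_pretrained("./tmp_adapter_config")  # writes adapter_config.json
loaded = PeftConfig.from_pretrained("./tmp_adapter_config")
print(loaded.peft_type, loaded.inference_mode)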
The important attributes of PeftConfig are:
- task_type: sometimes conflated with the term "model_type" (model_type in some contexts refers to the transformers architecture name, e.g. bert); this parameter must be specified for prefix_tuning, p_tuning, and prompt_tuning
- peft_type: which PEFT algorithm is used, e.g. PeftType.LORA
- is_prompt_learning: True only for prefix_tuning, p_tuning, and prompt_tuning
- is_adaption_prompt: True only for adaption_prompt
Summarized below (to be completed):

| Type | PeftConfig class | peft_type | task_type | is_prompt_learning | is_adaption_prompt |
| --- | --- | --- | --- | --- | --- |
| lora | LoraConfig | LORA | optional | False | False |
PeftModel is the core user-facing class; internally it holds special modules such as LoraModel or PrefixEncoder, and PeftModelForSequenceClassification is a subclass of PeftModel. PeftModel is designed so that multiple adapters can be registered and switched between:
lora_model = PeftModel(model, peft_config: PeftConfig, adapter_name="default")
lora_model.active_adapter # default
lora_model.peft_config # {"default": LoraConfig(...)}
lora_model.active_peft_config == lora_model.peft_config[lora_model.active_adapter]
lora_model.get_base_model() # always the underlying transformers PreTrainedModel
lora_model.base_model # depends on the algorithm: for lora it is a LoraModel, for prefix_tuning it is the PreTrainedModel itself
Next, let's explain what happens here, focusing on the following points in order of importance:
- PeftModel.__init__
- PeftModel.forward
- saving the model
- plugging lora adapters in and out
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
model_checkpoint = "google/vit-base-patch16-224-in21k"
model = AutoModelForImageClassification.from_pretrained(
model_checkpoint,
# label2id=label2id,
# id2label=id2label,
ignore_mismatched_sizes=True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
print_trainable_parameters(model)  # small helper defined in the PEFT docs example (not part of the library) that prints trainable parameter counts
config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["query", "value"],
lora_dropout=0.1,
bias="none",
modules_to_save=["classifier"],
)
lora_model = get_peft_model(model, config) # in this case, essentially equivalent to the line below
# lora_model = PeftModel(model, config, adapter_name="default")
print_trainable_parameters(lora_model)
In the example above, the most important step inside PeftModel(model, config) is executing self.base_model = LoraModel(model, config).
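To make the wrapping concrete, here is a small inspection sketch continuing the example above (the attribute path into the ViT encoder is my assumption about this particular checkpoint):

print(type(lora_model))                   # PeftModel
print(type(lora_model.base_model))        # peft.tuners.lora.LoraModel
print(type(lora_model.get_base_model()))  # ViTForImageClassification

# one of the replaced attention projections now carries per-adapter lora_A / lora_B
query = lora_model.get_base_model().vit.encoder.layer[0].attention.attention.query
print(type(query))                        # peft.tuners.lora.Linear
print(list(query.lora_A.keys()))          # ['default']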
peft/tuners/lora.py is essentially just the definitions of these classes:
LoraConfig(PeftConfig)
BaseTunerLayer(ABC) # abstract merge/unmerge methods, peft.tuners.tuners_utils.BaseTunerLayer
- LoraLayer(BaseTunerLayer)
- Linear(nn.Linear, LoraLayer)
- Embedding(nn.Embedding, LoraLayer)
- Conv2d(nn.Conv2d, LoraLayer)
- Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer) # bitsandbytes-related
- Linear4bit(bnb.nn.Linear4bit, LoraLayer) # bitsandbytes-related
- QuantLinear(torch.nn.Module, LoraLayer) # GPTQ-related
BaseTuner(ABC, nn.Module) # peft.tuners.tuners_utils.BaseTuner
LoraModel(BaseTuner)
First, a recap of the LoRA details:
- Linear: the adapted layer computes y = Wx + (lora_alpha / r) * B(A(x)), where A maps in_features -> r and B maps r -> out_features; merging simply adds (lora_alpha / r) * (B.weight @ A.weight) onto W (a short numeric check follows the Conv2d sketch below)
- Conv2d: suppose the original layer is layer = Conv2d(in_feat, out_feat, conv_size, stride, ...); with LoRA the kernel is decomposed into a low-rank pair of convolutions, which can later be folded back into a single kernel:

import torch
import torch.nn as nn
import torch.nn.functional as F

in_features = 3
out_features = 4
r = 2
kernel_size = (3, 3)
stride = 1
padding = 1
H, W = 10, 10
bias = True

x = torch.rand(1, in_features, H, W)  # the input was left implicit in the original snippet

layer_a = nn.Conv2d(in_features, r, kernel_size, stride, padding, bias=False)
layer_b = nn.Conv2d(r, out_features, (1, 1), (1, 1), bias=False)
layer = nn.Conv2d(in_features, out_features, kernel_size, stride, padding, bias=bias)

def forward():
    result = F.conv2d(
        x,
        layer.weight,
        bias=layer.bias,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
    )
    result = result + layer_b(layer_a(x))
    return result

def merged_forward():
    merged_w = F.conv2d(
        layer_a.weight.permute(1, 0, 2, 3),  # (in, r, 3, 3)
        layer_b.weight,                      # (out, r, 1, 1)
    ).permute(1, 0, 2, 3) + layer.weight     # -> (out, in, 3, 3)
    result = F.conv2d(
        x,
        merged_w,
        bias=layer.bias,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
    )
    return result

torch.allclose(forward(), merged_forward())  # True
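And the corresponding check for the Linear case stated above (a minimal sketch with made-up sizes; lora_B is deliberately initialized non-zero so the check is not trivially 0 == 0):

import torch
import torch.nn as nn

in_features, out_features, r, lora_alpha = 6, 4, 2, 8
scaling = lora_alpha / r

base = nn.Linear(in_features, out_features)
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
nn.init.normal_(lora_B.weight)

x = torch.rand(2, in_features)
unmerged = base(x) + scaling * lora_B(lora_A(x))

# fold the lora delta into a single Linear layer
merged = nn.Linear(in_features, out_features)
merged.weight.data = base.weight.data + scaling * (lora_B.weight @ lora_A.weight)
merged.bias.data = base.bias.data
print(torch.allclose(unmerged, merged(x), atol=1e-6))  # True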
Next, let's walk through the source of the inheritance chain BaseTunerLayer, LoraLayer, Linear:
# BaseTunerLayer only defines a class attribute active_adapter plus two abstract methods, merge and unmerge
class LoraLayer(BaseTunerLayer):
def __init__(self, in_features: int, out_features: int, **kwargs):
self.r = {}
self.lora_alpha = {}
self.scaling = {}
self.lora_dropout = nn.ModuleDict({})
self.lora_A = nn.ModuleDict({})
self.lora_B = nn.ModuleDict({})
# For Embedding layer
self.lora_embedding_A = nn.ParameterDict({})
self.lora_embedding_B = nn.ParameterDict({})
# Mark the weight as unmerged
self.merged = False
self.disable_adapters = False
self.in_features = in_features
self.out_features = out_features
self.kwargs = kwargs
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
if r > 0:
self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)}))
self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)}))
self.scaling[adapter_name] = lora_alpha / r # note: this scaling factor is used in forward
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)
def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): ...
def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): ...
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B[adapter_name].weight)
if adapter_name in self.lora_embedding_A.keys():
# for the Embedding layer, A is zero-initialized and B is normal-initialized (the reverse of the Linear case)
nn.init.zeros_(self.lora_embedding_A[adapter_name])
nn.init.normal_(self.lora_embedding_B[adapter_name])
class Linear(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
adapter_name: str,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_target_conv_1d_layer: bool = False,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T
nn.Linear.reset_parameters(self)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.is_target_conv_1d_layer = is_target_conv_1d_layer
def merge(self):
if self.active_adapter not in self.lora_A.keys():
return
if self.merged:
warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
self.weight.data += self.get_delta_weight(self.active_adapter)
self.merged = True
def unmerge(self):
if self.active_adapter not in self.lora_A.keys():
return
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
self.weight.data -= self.get_delta_weight(self.active_adapter)
self.merged = False
def get_delta_weight(self, adapter):
return (
transpose(
self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
self.fan_in_fan_out,
)
* self.scaling[adapter]
)
def forward(self, x: torch.Tensor):
previous_dtype = x.dtype
if self.active_adapter not in self.lora_A.keys():
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
x = x.to(self.lora_A[self.active_adapter].weight.dtype)
result += (
self.lora_B[self.active_adapter](
self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
)
* self.scaling[self.active_adapter]
)
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
result = result.to(previous_dtype)
return result
For our purposes, the following usage is all we need to focus on:
from peft.tuners.lora import Linear
import torch
def show_fn(layer, x):
print("main weight: ": layer.weight.data[:2, :2])
print("adapters:", layer.lora_A.keys())
print("active_adapter:", layer.active_adapter)
print("active_adapter_delta_weight:", layer.get_delta_weight(layer.active_adapter))
print("forward:", layer(x)[0])
# init_lora_weights defaults to True, which would make the added adapter's lora_B all zeros; we pass False so the adapter has a visible effect
layer = Linear(adapter_name="adapter_a", in_features=6, out_features=4, r=3, lora_alpha=1, lora_dropout=0.0, init_lora_weights=False)
layer.eval()  # put dropout into eval mode (setting layer.training directly would not propagate to child modules)
x = torch.rand(2, 6)
show_fn(layer, x)
layer.merge() # merge the extra weights of "adapter_a" into layer.weight
show_fn(layer, x)
layer.update_layer(adapter_name="adapter_b", r=2, lora_alpha=0.6, lora_dropout=0.5, init_lora_weights=False) # add parameters for a new adapter; this is always safe: it neither switches active_adapter nor merges weights
show_fn(layer, x)
# To be on the safe side, always unmerge before switching active_adapter:
layer.unmerge() # "adapter_a" must be unmerged first
layer.active_adapter = "adapter_b"
layer.merge() # merge "adapter_b"
show_fn(layer, x)
layer.weight, layer.bias # i.e. the weight and bias parameters of nn.Linear
layer.lora_A # a {adapter_name: nn.Linear} dict
Next, let's walk through the source of the chain (partly inheritance, partly containment) BaseTuner, LoraModel, PeftModel:
class BaseTuner:
# taking the subclass LoraModel as the example: the main purpose is to replace layers, e.g. swapping nn.Linear for peft.tuners.lora.Linear
def inject_adapter(self, model, adapter_name): ...
class LoraModel:
# enable/disable the adapter computation, implemented by toggling the disable_adapters flag on peft.tuners.lora.Linear
def enable_adapter_layers(self): ...
def disable_adapter_layers(self): ...
# set the active adapter
def set_adapter(self, adapter_name): ...
# delete an adapter
def delete_adapter(self, adapter_name): ...
@staticmethod
def _replace_module(parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
new_module.weight = child.weight # this line is crucial
...
# main code below: in short, it e.g. swaps peft.tuners.lora.Linear back to a plain nn.Linear
def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
for key in tqdm(key_list, disable=not progressbar,):
parent, target, target_name = _get_submodules(self.model, key)
if isinstance(target, peft.tuners.lora.Linear):
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=target.bias is not None)
if merge:
target.merge()
self._replace_module(parent, target_name, new_module, target)
if isinstance(target, ModulesToSaveWrapper):
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
def merge_and_unload(self, progressbar=False):
return self._unload_and_optionally_merge(merge=True, progressbar=progressbar)
def unload(self):
return self._unload_and_optionally_merge(merge=False, progressbar=False)
# the following methods can be ignored for now: they merge several adapters into one
def add_weighted_adapter(...): ...
def _svd_weighted_adapter(...): ...
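# Usage sketch (my own, continuing the earlier ViT lora_model example): merge_and_unload
# folds the lora delta into the base weights and returns a plain transformers model; the
# call reaches LoraModel through PeftModel.__getattr__, which forwards unknown attributes
# to base_model
merged = lora_model.merge_and_unload()
print(type(merged))  # back to the original ViTForImageClassification class
# every peft.tuners.lora.Linear has been swapped back to nn.Linear,
# with the lora delta already added into its weight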
class PeftModel:
def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"):
super().__init__()
self.base_model = model
self.config = getattr(self.base_model, "config", {"model_type": "custom"})
self.modules_to_save = None
self.peft_config = {}
self.active_adapter = adapter_name
self.peft_type = peft_config.peft_type
if not peft_config.is_prompt_learning: # lora goes down this branch: PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] is LoraModel
self.peft_config[adapter_name] = peft_config
self.base_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type](
self.base_model, self.peft_config, adapter_name
)
self.set_additional_trainable_modules(peft_config, adapter_name)
else:
self.add_adapter(adapter_name, peft_config)
# the rest is gradient_checkpointing and tensor-parallel related, can be ignored
if getattr(model, "is_gradient_checkpointing", True):
model = self._prepare_model_for_gradient_checkpointing(model)
if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
self.base_model.config.pretraining_tp = 1
def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
self.peft_config[adapter_name] = peft_config
try:
if peft_config.is_prompt_learning: # currently only prefix_tuning, p_tuning, prompt_tuning
if hasattr(self.config, "to_dict"):
dict_config = self.config.to_dict()
else:
dict_config = self.config
peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
self._setup_prompt_encoder(adapter_name)
elif peft_config.is_adaption_prompt: # currently only adaption_prompt
self.base_model.add_adapter(adapter_name, peft_config)
else: # Lora, ...
self.base_model.inject_adapter(self, adapter_name)
except Exception: # something went wrong, roll back
del self.peft_config[adapter_name]
raise
self.set_additional_trainable_modules(peft_config, adapter_name)
def set_additional_trainable_modules(self, peft_config, adapter_name):
if getattr(peft_config, "modules_to_save", None) is not None:
if self.modules_to_save is None:
self.modules_to_save = set(peft_config.modules_to_save)
else:
self.modules_to_save.update(peft_config.modules_to_save)
_set_trainable(self, adapter_name) # details omitted
def forward(self, *args: Any, **kwargs: Any):
return self.get_base_model()(*args, **kwargs)
# note: may return a PeftModel or a subclass such as PeftModelForSequenceClassification
@classmethod
def from_pretrained(...): ...
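As a usage sketch of the add_adapter path above (the model name and adapter names are made up; set_adapter is, to my knowledge, available on PeftModel in v0.5.0, but treat that as an assumption):

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, PeftModel

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = PeftModel(base, LoraConfig(target_modules=["query", "value"]), adapter_name="default")
peft_model.add_adapter("second", LoraConfig(r=4, target_modules=["query", "value"]))

print(list(peft_model.peft_config.keys()))  # ['default', 'second']
peft_model.set_adapter("second")
print(peft_model.active_adapter)            # second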
After a LoraModel has been instantiated, how can another adapter_name be added to it? Two possibilities (both untested, pieced together from the source):

# Method 1 (untested): modeled on BaseTuner.__init__
lora_model = LoraModel(...)
new_adapter_name, new_peft_config = "new_adapter", LoraConfig(...)
lora_model.peft_config.update({new_adapter_name: new_peft_config})
lora_model.inject_adapter(lora_model.model, new_adapter_name)  # side effect: the existing layers get replaced/updated
lora_model.model.peft_config = lora_model.peft_config

# Method 2 (untested): modeled on LoraModel._set_adapter_layers
new_adapter_name, new_peft_config = "new_adapter", LoraConfig(...)
lora_model.peft_config.update({new_adapter_name: new_peft_config})
for module in lora_model.model.modules():
    if isinstance(module, LoraLayer):
        module.update_layer(...)  # or update_layer_conv2d / update_layer_embedding
lora_model.model.peft_config = lora_model.peft_config
Subclasses of PeftModel, e.g. PeftModelForSequenceClassification, mainly just override the forward method.
Back to the beginning: in how many ways can a "peft model" be initialized?
# Way 1: PeftModel's __init__ method
# Way 2: get_peft_model: essentially just dispatches to the __init__ of PeftModel / PeftModelForSeq2SeqLM etc.
def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
"""
Returns a Peft model object from a model and a config.
Args:
model ([`transformers.PreTrainedModel`]): Model to be wrapped.
peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
"""
model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
return PeftModel(model, peft_config, adapter_name=adapter_name)
if peft_config.is_prompt_learning:
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
# Way 3: PeftModel.from_pretrained
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig
peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)
# Way 4: AutoPeftModel.from_pretrained, essentially AutoModelForXXX.from_pretrained + PeftModelForYYY.from_pretrained
from peft import AutoPeftModelForCausalLM, AutoPeftModel
peft_model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
peft_model = AutoPeftModel.from_pretrained("smangrul/openai-whisper-large-v2-LORA-colab")
To summarize the four ways:
- PeftModel.__init__ / PeftModelForSeq2SeqLM.__init__: the other ways are all built on top of these
- get_peft_model: dispatches to the __init__ of PeftModel / PeftModelForSeq2SeqLM etc.
- PeftModel.from_pretrained: __init__ plus load_state_dict
- AutoPeftModel.from_pretrained: AutoModelForXXX.from_pretrained + PeftModelForYYY.from_pretrained
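As a small round-trip sketch tying way 1 and way 3 together (continuing the earlier ViT example; the directory name is made up): save_pretrained on a PeftModel writes only the adapter (adapter_config.json plus the adapter weights), so reloading needs the base model again.

lora_model.save_pretrained("./vit_lora_adapter")

from transformers import AutoModelForImageClassification
from peft import PeftModel

base = AutoModelForImageClassification.from_pretrained(model_checkpoint, ignore_mismatched_sizes=True)
reloaded = PeftModel.from_pretrained(base, "./vit_lora_adapter")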
# lora
PeftModel: contains a LoraModel, which contains a PreTrainedModel whose nn.Linear layers have been replaced by peft.tuners.lora.Linear
# p_tuning: must use a PeftModelForXXX subclass
PeftModelForCausalLM: contains a PreTrainedModel directly; PeftModelForCausalLM.forward modifies the computation on top of it
Summary of the various peft methods:

For encoder-only / decoder-only architectures (PeftModelForCausalLM etc.):
- prompt_tuning / p_tuning: the learned prefix is simply concatenated onto the input embeddings (input_embedding); the rest of the computation is unchanged
- prefix_tuning: the input embeddings are untouched; instead, an extra key and value are appended at every layer

(A shape-level sketch of both mechanisms follows.)

For encoder-decoder architectures: [to be studied]
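A shape-level sketch of the two mechanisms (my own illustration with made-up sizes; this is not the actual PeftModelForCausalLM code):

import torch

batch, seq_len, num_virtual_tokens, token_dim = 2, 7, 10, 32
num_layers, num_heads, head_dim = 4, 4, 8  # token_dim == num_heads * head_dim

# prompt_tuning / p_tuning: prepend the learned embeddings to inputs_embeds
inputs_embeds = torch.rand(batch, seq_len, token_dim)
prompts = torch.rand(batch, num_virtual_tokens, token_dim)  # output of PromptEmbedding / PromptEncoder
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
print(inputs_embeds.shape)  # torch.Size([2, 17, 32]); the rest of the forward pass is unchanged

# prefix_tuning: inputs_embeds untouched; every layer receives an extra key and value
past_key_values = [
    (torch.rand(batch, num_heads, num_virtual_tokens, head_dim),   # extra key for this layer
     torch.rand(batch, num_heads, num_virtual_tokens, head_dim))   # extra value for this layer
    for _ in range(num_layers)
]
print(len(past_key_values), past_key_values[0][0].shape)  # 4 torch.Size([2, 4, 10, 8])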
prompt_tuning
# config:
# {"prompt_tuning_init_text": "hello", "prompt_tuning_init": "TEXT"}
# {"prompt_tuning_init": "RANDOM"}
class PromptEmbedding(torch.nn.Module):
def __init__(self, config, word_embeddings: nn.Embedding):
# word_embeddings is the base_model's embedding layer
super().__init__()
total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
if config.prompt_tuning_init == PromptTuningInit.TEXT: # TEXT or RANDOM
# first use word_embeddings to turn config.prompt_tuning_init_text into vectors, then match the length to total_virtual_tokens:
# if shorter than total_virtual_tokens, repeat: [11, 2, 3], 5 -> [11, 2, 3, 11, 2]
# if longer than total_virtual_tokens, truncate: [11, 2, 3, 4, 5], 3 -> [11, 2, 3]
# finally use the resulting embeddings to initialize self.embedding.weight
...
def forward(self, indices):
prompt_embeddings = self.embedding(indices)
return prompt_embeddings
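The repeat-or-truncate rule from the comments above, as a tiny standalone sketch (pure index manipulation, not the actual PEFT code):

import math

init_token_ids = [11, 2, 3]   # token ids of prompt_tuning_init_text
total_virtual_tokens = 5

num_text_tokens = len(init_token_ids)
if num_text_tokens > total_virtual_tokens:
    init_token_ids = init_token_ids[:total_virtual_tokens]
elif num_text_tokens < total_virtual_tokens:
    num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
    init_token_ids = (init_token_ids * num_reps)[:total_virtual_tokens]
print(init_token_ids)  # [11, 2, 3, 11, 2]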
p_tuning looks like it is just a reparameterization of prompt_tuning:
class PromptEncoder(torch.nn.Module):
def __init__(self, config):
...
self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim)
if not config.inference_mode:
if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
self.lstm_head = torch.nn.LSTM(...)
self.mlp_head = torch.nn.Sequential(...) # Linear, Relu, Linear
elif self.encoder_type == PromptEncoderReparameterizationType.MLP: # the default
self.mlp_head = torch.nn.Sequential(...) # Linear, Relu, Linear, Relu, Linear
else:
raise ValueError("...")
def forward(self, indices): # what about inference_mode=True? the inference_mode=True case does not seem to be handled here???
input_embeds = self.embedding(indices)
if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0])
elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
output_embeds = self.mlp_head(input_embeds)
else:
raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
return output_embeds
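A tiny instantiation sketch (my own; the field names are taken from PromptEncoderConfig and the import path assumes the v0.5.0 module layout):

import torch
from peft import PromptEncoderConfig
from peft.tuners.p_tuning import PromptEncoder

config = PromptEncoderConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=10,
    token_dim=32,
    num_transformer_submodules=1,
    num_attention_heads=4,
    num_layers=2,
    encoder_hidden_size=64,  # hidden size of the MLP/LSTM reparameterization head
)
encoder = PromptEncoder(config)
indices = torch.arange(config.num_virtual_tokens * config.num_transformer_submodules).unsqueeze(0)
print(encoder(indices).shape)  # torch.Size([1, 10, 32])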
prefix_tuning
class PrefixEncoder(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
token_dim = config.token_dim
num_layers = config.num_layers
encoder_hidden_size = config.encoder_hidden_size
num_virtual_tokens = config.num_virtual_tokens
if self.prefix_projection and not config.inference_mode: # in inference_mode only the embedding is kept
self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
self.transform = torch.nn.Sequential(
torch.nn.Linear(token_dim, encoder_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
)
else:
self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)
def forward(self, prefix: torch.Tensor): # each layer gets its own key and value, i.e. these added virtual tokens do not go through the normal attention projections
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.transform(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
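Finally, a sketch of how the PrefixEncoder output is typically reshaped into per-layer key/value pairs (this logic lives in PeftModel.get_prompt; the version below is simplified and the sizes are made up):

import torch

batch, num_virtual_tokens, num_layers, num_heads, head_dim = 2, 10, 4, 4, 8
token_dim = num_heads * head_dim

# output of PrefixEncoder: one (key, value) chunk per layer, flattened into the last dim
past_key_values = torch.rand(batch, num_virtual_tokens, num_layers * 2 * token_dim)
past_key_values = past_key_values.view(
    batch, num_virtual_tokens, num_layers * 2, num_heads, head_dim
)
# -> (num_layers * 2, batch, num_heads, num_virtual_tokens, head_dim)
past_key_values = past_key_values.permute(2, 0, 3, 1, 4)
# split into num_layers chunks, each stacking one key and one value tensor
past_key_values = past_key_values.split(2)
print(len(past_key_values), past_key_values[0].shape)  # 4 torch.Size([2, 2, 4, 10, 8])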