动机、参考资料、涉及内容

openai-python, langchain, llama_index, fastapi 等项目大量用到了 pydantic, 并且 pydantic 本身也有 V1 版本与 V2 版本的区别. 对 V1 版本的使用方式是 pydantic.v1

Questions:

  • BaseModel 里的 Config 是什么
  • 装饰器 root_validator
  • langchain: AIMessage.update_forward_refs()
  • json-schema, 有可能另起一篇
  • 较新版的 fastapi 应该使用 v2?, 但怎么做到客户端可以多传一些参数, 接口调用正常, 并且服务端怎么获取在 BaseModel 校验前的原始参数

V1

例子

本节展示一个例子, 基本上能覆盖大多数使用

from pydantic.v1 import BaseModel, Field, Extra, validator, root_validator
from typing import Annotated, Optional, List
class Request(BaseModel):
    # query, temperature, other_notes 展示了几种 type hint 的写法
    query: str
    temperature: float = Field(description="the temperature", ge=0.0, lt=2.0)  # pydantic 会检查 Field 定义的约束 
    other_notes: Annotated[str, Field(description="tools", examples=["calculator", "python"])]
    stop_words: Optional[List[str]] = None

    # pydantic 的一些内置检查选项
    class Config:
        max_anystr_length = 10  # 任何字符串形式的字段长度不超过 10
        extra = Extra.forbid  # 禁止传入多余字段

    # 通过指定 pre=True 先于后面的 validate_stop_word_length 检查
    @validator("stop_words", pre=True)
    def split_stop_words(cls, v):
        if isinstance(v, str):
            return v.split("|")
        return v

    @validator("stop_words")
    def validate_stop_word_length(cls, v):
        # 至多只能设置 4 个 stop word
        if len(v) > 4:
            raise ValueError(f'stop words more than 4')
        return v  # 注意需要返回数据
    
    # 可以对多个字段采用相同的检查
    @validator("query", "other_notes")
    def validate_min_length(cls, v):
        if len(v) == 0:
            raise ValueError(f"empty string")
        return v
    
    # 对整个数据结构进行整体检查
    @root_validator
    def validate_context_length(cls, values):
        query = values.get("query")
        other_notes = values.get("other_notes")
        if len(query) + len(other_notes) > 15:
            raise ValueError("context length more than 15")
        return values


req = Request(temperature=1.0, other_notes="note note", query="2+3", stop_words=["2", "3", "4"])
req = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4")
# err = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4", xx = 2)  # Error!
print(req.dict())  # 转换为字典, v2 应该使用 model_dump
print(Request.schema())  # 输出 json schema, v2 应该使用 model_json_schema

输出:

# req.dict()
{'query': '1+1',
 'temperature': 1.0,
 'other_notes': 'calculate',
 'stop_words': ['2', '3', '4']}

# Request.schema()
{'title': 'Request',
 'type': 'object',
 'properties': {'query': {'title': 'Query', 'type': 'string'},
  'temperature': {'title': 'Temperature',
   'description': 'the temperature',
   'exclusiveMaximum': 2.0,
   'minimum': 0.0,
   'type': 'number'},
  'other_notes': {'title': 'Other Notes',
   'description': 'tools',
   'examples': ['calculator', 'python'],
   'type': 'string'},
  'stop_words': {'title': 'Stop Words',
   'type': 'array',
   'items': {'type': 'string'}}},
 'required': ['query', 'temperature', 'other_notes'],
 'additionalProperties': False}

这个例子用 pydantic V2 写如下: 总的来说差异还是比较多的, 主要是各种方法名, 字段名的修改

from pydantic import BaseModel, Field, model_validator, field_validator, ConfigDict
from typing import Annotated, Optional, List


class Request(BaseModel):
    query: str
    temperature: float = Field(description="the temperature", ge=0.0, lt=2.0)  # pydantic 会检查 Field 定义的约束 
    other_notes: Annotated[str, Field(description="tools", examples=["calculator", "python"])]
    stop_words: Optional[List[str]] = None
    # Config 类变成了一个字段: model_config
    # Extra.forbit 变成了字符串 "forbid"
    model_config = ConfigDict(str_max_length=10, extra="forbid")

    # 注意 pre=True/False 改为了 mode="after"/"before"
    # validator (V1) -> field_validator (V2)
    @field_validator("stop_words", mode="before")
    @classmethod  # 注意需要增加 classmethod 装饰器, 且需要位于 field_validator 之后
    def split_stop_words(cls, v):
        if isinstance(v, str):
            return v.split("|")
        return v

    @field_validator("stop_words")
    @classmethod
    def validate_stop_word_length(cls, v):
        if len(v) > 4:
            raise ValueError(f'stop words more than 4')
        return v
    
    @field_validator("query", "other_notes")
    @classmethod
    def validate_min_length(cls, v):
        if len(v) == 0:
            raise ValueError(f"empty string")
        return v
    
    # root_validator -> model_validator
    @model_validator(mode="after")
    @classmethod  # 注意需要增加 classmethod 装饰器, 且需要位于 model_validator 之后
    def validate_context_length(cls, values):
        query = values.query  # 注意 V2 用的点运算符, values: Request
        other_notes = values.other_notes
        if len(query) + len(other_notes) > 15:
            raise ValueError("context length more than 15")
        return values


req = Request(temperature=1.0, other_notes="note note", query="2+3", stop_words=["2", "3", "4"])
req = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4")
# err = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4", xx = 2)  # Error!
print(req.model_dump())  # dict -> model_dump
print(Request.model_json_schema())   # schema -> model_json_schema

Type Hint & Field & Annotated

继承自 BaseModel 的类的属性必须有 type hint, 有以下三种方式:

  • 只使用普通的 type hint: 这种情况下, pydantic 会去校验数据项是否满足类型约束
  • 使用普通的 type hint, 再补充一个 Field: 这种情况下, pydantic 会去校验数据项是否满足类型约束, 并且会检查 Field 中描述的约束
  • 使用 typing.Annotated, 本质上与第二种方法一样.

以下是一个例子

from typing import Annotated
from pydantic import BaseModel, Field
class MyModel(BaseModel):
    a: str
    b: str = Field(default="abc", title="bbb")     # 这种写法兼容性较高
    c: Annotated[str, Field(title="ccc")] = "abc"  # 注意默认值的优先级是先看等号后面的, 再看 Field 里面的 default 字段

备注:

data: typing.Annotated[T, x] 是对普通的 type hint 的增强, 其中 T 时类型名, x 是任意数据, 代表 metadata. 在 python 运行时, 无论是 type hint 以及 metadata, 都不会对 data 本身做校验. 但 pydantic 会利用这些信息进行数据校验.

x: Annotated[str, "desc"] = "123"
Annotated[str, "desc"].__metadata__  # ("desc",)

Field 实际上是一个函数, 其返回类型是 FieldInfo

Field(description="the temperature", ge=0.0, lt=2.0)
# FieldInfo(default=PydanticUndefined, description='the temperature', ge=0.0, lt=2.0, extra={})

Validator

字段校验次序参考 https://docs.pydantic.dev/1.10/usage/models/#field-ordering, 简单来说与字段定义的书写顺序相关, 也与 validator(pre=True) 里的 pre 参数相关.

Config

Config 是 pydantic 内置的一些校验方法, 而 Validator 是自定义的校验手段

V1 to V2

感觉 API 变化很大, 不理解为什么要从 V1 升到 V2 (TODO)

  • llama_index (v0.9.31, 发布时间 2024/1/16) 使用的是 V1
  • langchain (v0.1.0, 发布时间 2024/01/06): 似乎在试图兼容 V1 与 V2, 但是否实际都是使用 pydantic.v1?
  • openai-python (v1.2.3, 发布时间 2023/11/10): 似乎在试图兼容 V1 与 V2, 但是否实际都是使用 pydantic.v1?
  • fastapi: 不确定?

bump-pydantic

add default none 为例, 探究实现细节.

转换前

# repo_folder/my_package/a.py
from pydantic import BaseModel

class Foo(BaseModel):
    bar: Optional[str]
    baz: Union[str, None]
    qux: Any

转换方法

pip install bump-pydantic
cd repo_folder
bump-pydantic my_package

转换后

# repo_folder/my_package/a.py
from pydantic import BaseModel

class Foo(BaseModel):
    bar: Optional[str] = None
    baz: Union[str, None] = None
    qux: Any = None

大致原理是利用 libcst 构建 concrete syntax tree, 并且这个过程是无损的 (libcst 可以精确还原为原始代码, 而 python 自带的 ast 模块是有损的, 特别地, ast 不能精确还原空格和空行), 下面是一个简单的例子:

import libcst
module = libcst.parse_module("a =(( 1+2))")  # 一个树状的数据结构
code = module.code  # "a =(( 1+2))"

bump-pydantic (以 add default none 为例) 的实现原理就是修改 module, 然后还原为代码, 具体实现上则借助了 libcst 的一些内置工具, 大略如下.

# 参考自: bump_pydantic/codemods/add_default_none.py
import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName
from libcst.metadata import FullRepoManager

class AddDefaultNoneCommand(VisitorBasedCodemodCommand):
    # 这里的 ClassDef 与 AnnAssign 是 libcst 中的节点类型, 在调用 visit 方法时, 会触发下面的这些方法:
    # 判断是否在 BaseModel 内
    def visit_ClassDef(self, node: cst.ClassDef) -> None: ...
    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: ...
    # 判断是否需要添加 `= None`
    def visit_AnnAssign(self, node: cst.AnnAssign) -> None: ...
    # 实现修改 `bar: Optional[str]` 为 `bar: Optional[str] = None` 的逻辑
    def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: ...

tmpdir, module = "./", "./package_name/a.py"
mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider})
wrapper = mrg.get_metadata_wrapper_for_path(module)
context = CodemodContext(wrapper=wrapper)

command = AddDefaultNoneCommand(context=context)  # type: ignore[assignment]
mod = wrapper.visit(command)
print(mod.code)

pydantic 的更多使用例子

例子 1: A 属性的值由 B 属性确定

要点: 优先考虑 @property 装饰器 以及 @pydantic.root_validator 装饰器

方案 1: B 属性也定为了 field

方案 1.1: 使用 root_validator

参考: langchain_community/embeddings/baichuan.py: BaichuanTextEmbeddings:session

from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any

class A(BaseModel):
    embedding_functions: Dict[str, Any]
    vector_names: List[str]  #: :meta private:
    
    @root_validator(pre=True)
    def add_vector_names(cls, values):
        values["vector_names"] = list(values["embedding_functions"].keys())
        return values

    @classmethod
    def from_functions(cls, embedding_functions): 
        return cls(embedding_functions=embedding_functions)

embedding_functions = {"a": lambda x: [1]}
a = A.from_functions(embedding_functions)
a = A(embedding_functions=embedding_functions)

方案 1.2: 重载 __init__

参考: langchain_community/retrievers/arcee.py: _client

这种写法应该是方案 1.1 的备选方案, 不甚优雅. 注意如果将 self.vector_names=... 放在 super().__init__(...) 之后时, 必须将 vector_name 的 type hist 设置为 Optional.

from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any

class A(BaseModel):
    embedding_functions: Dict[str, Any]
    vector_names: Optional[List[str]] = None  #: :meta private:
    
    def __init__(self, **data):
        super().__init__(**data)
        self.vector_names = list(embedding_functions.values())

    @classmethod
    def from_functions(cls, embedding_functions): 
        return cls(embedding_functions=embedding_functions)

embedding_functions = {"a": lambda x: [1]}
a = A.from_functions(embedding_functions)
a = A(embedding_functions=embedding_functions)

方案 2: 使用 property 装饰器, 可以避免将 B 属性设置为 field

方案 2.1

此写法参考: langchain_community/retrievers/azure_ai_search.py: _headers

from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any

class A(BaseModel):
    embedding_functions: Dict[str, Any]
    
    @property
    def vector_names(self):
        return len(self.embedding_functions)

    @classmethod
    def from_functions(cls, embedding_functions): 
        return cls(embedding_functions=embedding_functions)

embedding_functions = {"a": lambda x: [1]}
a = A.from_functions(embedding_functions)
a = A(embedding_functions=embedding_functions)

方案 2.2

如果 B 属性的计算比较费时, 并且不会根据 A 属性产生变化, 可以参考这个写法, 但不够优雅. 写法来源于 GPT-4

from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any

class A(BaseModel):
    embedding_functions: Dict[str, Any]
    
    @property
    def vector_names(self):
        if "_vector_name" not in self.__dict__:
            self.__dict__["_vector_name"] = len(self.embedding_functions)
        return self.__dict__["_vector_name"]

    @classmethod
    def from_functions(cls, embedding_functions): 
        return cls(embedding_functions=embedding_functions)

方案 2.3

这个例子不太贴合需求: langchain_core/retrievers.py: BaseRetriever.__init_subclass__