(Alpha) pydantic tutorial
动机、参考资料、涉及内容
openai-python
, langchain
, llama_index
, fastapi
等项目大量用到了 pydantic
, 并且 pydantic
本身也有 V1 版本与 V2 版本的区别. 对 V1 版本的使用方式是 pydantic.v1
Questions:
BaseModel
里的Config
是什么- 装饰器
root_validator
- langchain:
AIMessage.update_forward_refs()
- json-schema, 有可能另起一篇
- 较新版的 fastapi 应该使用 v2?, 但怎么做到客户端可以多传一些参数, 接口调用正常, 并且服务端怎么获取在 BaseModel 校验前的原始参数
V1
例子
本节展示一个例子, 基本上能覆盖大多数使用
from pydantic.v1 import BaseModel, Field, Extra, validator, root_validator
from typing import Annotated, Optional, List
class Request(BaseModel):
# query, temperature, other_notes 展示了几种 type hint 的写法
query: str
temperature: float = Field(description="the temperature", ge=0.0, lt=2.0) # pydantic 会检查 Field 定义的约束
other_notes: Annotated[str, Field(description="tools", examples=["calculator", "python"])]
stop_words: Optional[List[str]] = None
# pydantic 的一些内置检查选项
class Config:
max_anystr_length = 10 # 任何字符串形式的字段长度不超过 10
extra = Extra.forbid # 禁止传入多余字段
# 通过指定 pre=True 先于后面的 validate_stop_word_length 检查
@validator("stop_words", pre=True)
def split_stop_words(cls, v):
if isinstance(v, str):
return v.split("|")
return v
@validator("stop_words")
def validate_stop_word_length(cls, v):
# 至多只能设置 4 个 stop word
if len(v) > 4:
raise ValueError(f'stop words more than 4')
return v # 注意需要返回数据
# 可以对多个字段采用相同的检查
@validator("query", "other_notes")
def validate_min_length(cls, v):
if len(v) == 0:
raise ValueError(f"empty string")
return v
# 对整个数据结构进行整体检查
@root_validator
def validate_context_length(cls, values):
query = values.get("query")
other_notes = values.get("other_notes")
if len(query) + len(other_notes) > 15:
raise ValueError("context length more than 15")
return values
req = Request(temperature=1.0, other_notes="note note", query="2+3", stop_words=["2", "3", "4"])
req = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4")
# err = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4", xx = 2) # Error!
print(req.dict()) # 转换为字典, v2 应该使用 model_dump
print(Request.schema()) # 输出 json schema, v2 应该使用 model_json_schema
输出:
# req.dict()
{'query': '1+1',
'temperature': 1.0,
'other_notes': 'calculate',
'stop_words': ['2', '3', '4']}
# Request.schema()
{'title': 'Request',
'type': 'object',
'properties': {'query': {'title': 'Query', 'type': 'string'},
'temperature': {'title': 'Temperature',
'description': 'the temperature',
'exclusiveMaximum': 2.0,
'minimum': 0.0,
'type': 'number'},
'other_notes': {'title': 'Other Notes',
'description': 'tools',
'examples': ['calculator', 'python'],
'type': 'string'},
'stop_words': {'title': 'Stop Words',
'type': 'array',
'items': {'type': 'string'}}},
'required': ['query', 'temperature', 'other_notes'],
'additionalProperties': False}
这个例子用 pydantic V2 写如下: 总的来说差异还是比较多的, 主要是各种方法名, 字段名的修改
from pydantic import BaseModel, Field, model_validator, field_validator, ConfigDict
from typing import Annotated, Optional, List
class Request(BaseModel):
query: str
temperature: float = Field(description="the temperature", ge=0.0, lt=2.0) # pydantic 会检查 Field 定义的约束
other_notes: Annotated[str, Field(description="tools", examples=["calculator", "python"])]
stop_words: Optional[List[str]] = None
# Config 类变成了一个字段: model_config
# Extra.forbit 变成了字符串 "forbid"
model_config = ConfigDict(str_max_length=10, extra="forbid")
# 注意 pre=True/False 改为了 mode="after"/"before"
# validator (V1) -> field_validator (V2)
@field_validator("stop_words", mode="before")
@classmethod # 注意需要增加 classmethod 装饰器, 且需要位于 field_validator 之后
def split_stop_words(cls, v):
if isinstance(v, str):
return v.split("|")
return v
@field_validator("stop_words")
@classmethod
def validate_stop_word_length(cls, v):
if len(v) > 4:
raise ValueError(f'stop words more than 4')
return v
@field_validator("query", "other_notes")
@classmethod
def validate_min_length(cls, v):
if len(v) == 0:
raise ValueError(f"empty string")
return v
# root_validator -> model_validator
@model_validator(mode="after")
@classmethod # 注意需要增加 classmethod 装饰器, 且需要位于 model_validator 之后
def validate_context_length(cls, values):
query = values.query # 注意 V2 用的点运算符, values: Request
other_notes = values.other_notes
if len(query) + len(other_notes) > 15:
raise ValueError("context length more than 15")
return values
req = Request(temperature=1.0, other_notes="note note", query="2+3", stop_words=["2", "3", "4"])
req = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4")
# err = Request(temperature=1.0, other_notes="calculate", query="1+1", stop_words="2|3|4", xx = 2) # Error!
print(req.model_dump()) # dict -> model_dump
print(Request.model_json_schema()) # schema -> model_json_schema
Type Hint & Field & Annotated
继承自 BaseModel
的类的属性必须有 type hint, 有以下三种方式:
- 只使用普通的 type hint: 这种情况下, pydantic 会去校验数据项是否满足类型约束
- 使用普通的 type hint, 再补充一个
Field
: 这种情况下, pydantic 会去校验数据项是否满足类型约束, 并且会检查Field
中描述的约束 - 使用
typing.Annotated
, 本质上与第二种方法一样.
以下是一个例子
from typing import Annotated
from pydantic import BaseModel, Field
class MyModel(BaseModel):
a: str
b: str = Field(default="abc", title="bbb") # 这种写法兼容性较高
c: Annotated[str, Field(title="ccc")] = "abc" # 注意默认值的优先级是先看等号后面的, 再看 Field 里面的 default 字段
备注:
data: typing.Annotated[T, x]
是对普通的 type hint 的增强, 其中 T
时类型名, x
是任意数据, 代表 metadata. 在 python 运行时, 无论是 type hint 以及 metadata, 都不会对 data
本身做校验. 但 pydantic 会利用这些信息进行数据校验.
x: Annotated[str, "desc"] = "123"
Annotated[str, "desc"].__metadata__ # ("desc",)
Field
实际上是一个函数, 其返回类型是 FieldInfo
Field(description="the temperature", ge=0.0, lt=2.0)
# FieldInfo(default=PydanticUndefined, description='the temperature', ge=0.0, lt=2.0, extra={})
Validator
字段校验次序参考 https://docs.pydantic.dev/1.10/usage/models/#field-ordering, 简单来说与字段定义的书写顺序相关, 也与 validator(pre=True)
里的 pre
参数相关.
Config
Config 是 pydantic 内置的一些校验方法, 而 Validator 是自定义的校验手段
V1 to V2
感觉 API 变化很大, 不理解为什么要从 V1 升到 V2 (TODO)
llama_index
(v0.9.31, 发布时间 2024/1/16) 使用的是 V1langchain
(v0.1.0, 发布时间 2024/01/06): 似乎在试图兼容 V1 与 V2, 但是否实际都是使用 pydantic.v1?openai-python
(v1.2.3, 发布时间 2023/11/10): 似乎在试图兼容 V1 与 V2, 但是否实际都是使用 pydantic.v1?fastapi
: 不确定?
bump-pydantic
以 add default none 为例, 探究实现细节.
转换前
# repo_folder/my_package/a.py
from pydantic import BaseModel
class Foo(BaseModel):
bar: Optional[str]
baz: Union[str, None]
qux: Any
转换方法
pip install bump-pydantic
cd repo_folder
bump-pydantic my_package
转换后
# repo_folder/my_package/a.py
from pydantic import BaseModel
class Foo(BaseModel):
bar: Optional[str] = None
baz: Union[str, None] = None
qux: Any = None
大致原理是利用 libcst 构建 concrete syntax tree, 并且这个过程是无损的 (libcst 可以精确还原为原始代码, 而 python 自带的 ast 模块是有损的, 特别地, ast 不能精确还原空格和空行), 下面是一个简单的例子:
import libcst
module = libcst.parse_module("a =(( 1+2))") # 一个树状的数据结构
code = module.code # "a =(( 1+2))"
bump-pydantic
(以 add default none 为例) 的实现原理就是修改 module
, 然后还原为代码, 具体实现上则借助了 libcst 的一些内置工具, 大略如下.
# 参考自: bump_pydantic/codemods/add_default_none.py
import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.metadata import FullyQualifiedNameProvider, QualifiedName
from libcst.metadata import FullRepoManager
class AddDefaultNoneCommand(VisitorBasedCodemodCommand):
# 这里的 ClassDef 与 AnnAssign 是 libcst 中的节点类型, 在调用 visit 方法时, 会触发下面的这些方法:
# 判断是否在 BaseModel 内
def visit_ClassDef(self, node: cst.ClassDef) -> None: ...
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: ...
# 判断是否需要添加 `= None`
def visit_AnnAssign(self, node: cst.AnnAssign) -> None: ...
# 实现修改 `bar: Optional[str]` 为 `bar: Optional[str] = None` 的逻辑
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.AnnAssign: ...
tmpdir, module = "./", "./package_name/a.py"
mrg = FullRepoManager(tmpdir, {module}, providers={FullyQualifiedNameProvider})
wrapper = mrg.get_metadata_wrapper_for_path(module)
context = CodemodContext(wrapper=wrapper)
command = AddDefaultNoneCommand(context=context) # type: ignore[assignment]
mod = wrapper.visit(command)
print(mod.code)
pydantic 的更多使用例子
例子 1: A 属性的值由 B 属性确定
要点: 优先考虑 @property
装饰器 以及 @pydantic.root_validator
装饰器
方案 1: B 属性也定为了 field
方案 1.1: 使用 root_validator
参考: langchain_community/embeddings/baichuan.py
: BaichuanTextEmbeddings:session
from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any
class A(BaseModel):
embedding_functions: Dict[str, Any]
vector_names: List[str] #: :meta private:
@root_validator(pre=True)
def add_vector_names(cls, values):
values["vector_names"] = list(values["embedding_functions"].keys())
return values
@classmethod
def from_functions(cls, embedding_functions):
return cls(embedding_functions=embedding_functions)
embedding_functions = {"a": lambda x: [1]}
a = A.from_functions(embedding_functions)
a = A(embedding_functions=embedding_functions)
方案 1.2: 重载 __init__
参考: langchain_community/retrievers/arcee.py
: _client
这种写法应该是方案 1.1 的备选方案, 不甚优雅. 注意如果将 self.vector_names=...
放在 super().__init__(...)
之后时, 必须将 vector_name
的 type hist 设置为 Optional
.
from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any
class A(BaseModel):
embedding_functions: Dict[str, Any]
vector_names: Optional[List[str]] = None #: :meta private:
def __init__(self, **data):
super().__init__(**data)
self.vector_names = list(embedding_functions.values())
@classmethod
def from_functions(cls, embedding_functions):
return cls(embedding_functions=embedding_functions)
embedding_functions = {"a": lambda x: [1]}
a = A.from_functions(embedding_functions)
a = A(embedding_functions=embedding_functions)
方案 2: 使用 property 装饰器, 可以避免将 B 属性设置为 field
方案 2.1
此写法参考: langchain_community/retrievers/azure_ai_search.py
: _headers
from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any
class A(BaseModel):
embedding_functions: Dict[str, Any]
@property
def vector_names(self):
return len(self.embedding_functions)
@classmethod
def from_functions(cls, embedding_functions):
return cls(embedding_functions=embedding_functions)
embedding_functions = {"a": lambda x: [1]}
a = A.from_functions(embedding_functions)
a = A(embedding_functions=embedding_functions)
方案 2.2
如果 B 属性的计算比较费时, 并且不会根据 A 属性产生变化, 可以参考这个写法, 但不够优雅. 写法来源于 GPT-4
from pydantic.v1 import BaseModel, root_validator
from typing import List, Dict, Any
class A(BaseModel):
embedding_functions: Dict[str, Any]
@property
def vector_names(self):
if "_vector_name" not in self.__dict__:
self.__dict__["_vector_name"] = len(self.embedding_functions)
return self.__dict__["_vector_name"]
@classmethod
def from_functions(cls, embedding_functions):
return cls(embedding_functions=embedding_functions)
方案 2.3
这个例子不太贴合需求: langchain_core/retrievers.py
: BaseRetriever.__init_subclass__