(Ready) tiktoken 详解
动机、参考资料、涉及内容
- (Ready) tiktoken 的实现细节, 包括 python 实现与 rust 实现, 完全以学习的角度看: https://github.com/openai/tiktoken
- (Unknown) 为什么 tiktoken 快过 huggingface 的实现: 具体的测试脚本? huggingface 的 FastTokenizer 也是用 Rust 实现的, 为什么会慢?
tiktoken
使用
tiktoken 是 BPE 算法的实现, 最重要的特点是 encode 与 decode 是无损的
简单使用
from tiktoken import get_encoding
from typing import List
enc = get_encoding("cl100k_base")
text = "hello world"
tokens: List[int] = enc.encode("hello world")
decode_text: str = enc.decode(tokens)
assert text == decode_text
关于 special_token 的处理
enc.encode("<|endofprompt|>", allowed_special=set(), disallowed_special="all") # 默认值, 触发 Error
enc.encode("<|endofprompt|>", allowed_special="all", disallowed_special="all") # [100276]
# 源码中的处理逻辑如下:
if allowed_special == "all":
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - allowed_special
# 只要 disallowed_special 不为空, 输入的 text 就不能包含 disallowed_special 中的字符
if disallowed_special:
if match := _special_token_regex(disallowed_special).search(text):
raise_disallowed_special_token(match.group())
# 这个 _core_bpe 的类型是 Rust 绑定到 Python 的
return self._core_bpe.encode(text, allowed_special)
前置说明
下面以 cl100k_base
为例来分析 enc = get_encoding("cl100k_base")
的具体流程 (gpt3.5, gpt4, text-embedding-ada-002 这几个模型都是使用这个 tokenizer)
# tiktoken.get_encoding("cl100k_base") 的实际调用流程
ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"
def cl100k_base():
# cl100k_base.tiktoken 本身是一个文本文件, 每一行是 base64 表示的 token 表示, 以及相应的 token 序号
mergeable_ranks: dict[bytes, int] = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
)
# load_tiktoken_bpe 的实际过程就是下载文件(或者从缓存中读取), 然后像这样解析
# lines = open("cl100k_base.tiktoken", "rb").read().splitlines()
# mergeable_ranks = {
# base64.b64decode(token): int(rank)
# for token, rank in (line.split() for line in lines if line)
# }
special_tokens = {
ENDOFTEXT: 100257,
FIM_PREFIX: 100258,
FIM_MIDDLE: 100259,
FIM_SUFFIX: 100260,
ENDOFPROMPT: 100276,
}
return {
"name": "cl100k_base",
"pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
tokenizer = Encoding(**cl100k_base())
值得注意点主要包括 mergeable_ranks
, pat_str
, Encoding
这几处
(1) 这里先探索一下 mergeable_ranks
- 前 256 个 token 完全覆盖了
\x00
~\xff
这2^8=256
个单字节字符import base64 lines = open("cl100k_base.tiktoken", "rb").read().splitlines() mergeable_ranks = { base64.b64decode(token): int(rank) for token, rank in (line.split() for line in lines if line) } a = set([ord(token) for token, num in list(mergeable_ranks.items())[:256]]) b = set([i for i in range(256)]) a == b # True
mergeable_ranks
长度为 100256, 因此特殊 token 的序号刚好错开all([i == num for i, (token, num) in enumerate(list(mergeable_ranks.items()))]) == True # True len(mergeable_ranks) # 100256
(2) 这里的 Encoding
主要是对 rust 实现的 tokenizer 的封装, 相应的代码如下
# tiktoken/core.py
class Encoding:
def __init__(self, ...):
# 其余代码从略 ...
self._special_tokens = special_tokens
# _tiktoken 是一个 rust 编译出的的Python扩展动态链接库, 在 pip install tiktoken 后在硬盘上大约位于
# site-packages/tiktoken/_tiktoken.cpython-39-x86_64-linux-gnu.so
self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)
@functools.cached_property
def special_tokens_set(self) -> set[str]:
return set(self._special_tokens.keys())
def encode(
self,
text: str,
*,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
if allowed_special == "all":
allowed_special = self.special_tokens_set
if disallowed_special == "all":
disallowed_special = self.special_tokens_set - allowed_special
if disallowed_special: # 如果 disallowed_special = set("")
if not isinstance(disallowed_special, frozenset):
disallowed_special = frozenset(disallowed_special)
if match := _special_token_regex(disallowed_special).search(text):
raise_disallowed_special_token(match.group())
if isinstance(allowed_special, frozenset):
allowed_special = set(allowed_special)
try:
# 大多数情况是进入这个分支
return self._core_bpe.encode(text, allowed_special)
except UnicodeEncodeError: # 如果有兴趣, 这个分支可参考源码说明
text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
return self._core_bpe.encode(text, allowed_special)
def decode(self, tokens: list[int], errors: str = "replace") -> str:
return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
由此可以看出工作完全是交接给 Rust 实现的 CoreBPE, 这便是本文的目标
(3) 这个 pat_str
用在预切分, 这里以 gpt-2 举例, 这个正则的特征是能将原始文本完全切分, 也就是切分完后可以重新精确恢复原始文本.
import regex
pat = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
text = "a's 1,123 abc 中国人"
parts = regex.findall(pat, text)
# ['a', "'s", ' 1', ',', '123', ' ', ' abc', ' ', ' 中国人']
# 可以看出对英文比较友好, 而连续的汉字会切分不开
"".join(parts) == text # True, 这个性质是重点!!!!
总的来说就是一堆“或”关系的匹配, 具体可以拆解为:
'(?:[sdmt]|ll|ve|re)
: 实际上就是代表匹配's|'d|'m|'t|'ll|'ve|'re
, 基本上就是常见的缩写- ` ?\p{L}+`: 多个连续的 Unicode 文字
- ` ?\p{N}+`: 多个连续的 Unicode 数字
- ` ?[^\s\p{L}\p{N}]+`: 非空白, 非 Unicode 文字,
\s+(?!\S)
: 多个空白, 但不包含最后一个\s+
: 多个空白
备注: Unicode 字符集除了 p{L}
(文字) 和 p{N}
(数字) 外, 还有符号, 空白符, emoji 等等
备注: 最前面的 ?:
只是代表是非捕获组, 解释如下:
import re
text = "today is 2024-05-15, tomorrow is 2024-05-16"
pat1 = r"(\d{4})-(\d{2})-(\d{2})"
pat2 = r"(?:\d{4})-(\d{2})-(\d{2})"
print("捕获组")
for result in re.findall(pat1, text):
print(result)
print("非捕获组")
for result in re.findall(pat2, text):
print(result)
输出结果如下(其实只是结果里不包含非捕获组而已)
捕获组
('2024', '05', '15')
('2024', '05', '16')
非捕获组
('05', '15')
('05', '16')
cl100k_base 的这个 pat_str
如下
pat_str = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
值得注意的是, 这个 “正则” 不符合 python 的标注: 例如里面包含这种子序列 ?+
, 这对于 python 正则是不合法的, 而像 \p{L}
代表匹配任意的 Unicode 文本字符, 这也同样不能用 python 的 re 模块所解析 (可以借助第三方包 regex
). 后续会重新看看这个正则(从 Rust 的实现来看).
bpe 算法
完整实现请直接参考 tiktoken/_education.py
关于 bytes
的注解:
bytes(3) # 表示 3 个 \x00, 即: b'\x00\x00\x00'
bytes([3, 4]) # 表示 b'\x03\x04'
b''.join([bytes([3]), bytes([4])]) # [b'\x03', b'\x04'] -> b'\x03\x04'
train
1) 初始化:
import regex
pat_str = r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
data = "中国人, haven't"
words: list[list[bytes]] = [
[bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
]
# 具体过程为:
# regex.findall(pat_str, data) # ['中国人', ',', ' haven', "'t"]
# 然后再将每个部分分解为字节:
# words = [
# [b'\xe4', b'\xb8', b'\xad', b'\xe5', b'\x9b', b'\xbd', b'\xe4', b'\xba', b'\xba'],
# [b','],
# [b' ', b'h', b'a', b'v', b'e', b'n'],
# [b"'", b't']
# ]
# assert b''.join(words[0]).decode() == "中国人"
2) 迭代:
while len(ranks) < vocab_size: # vocab_size 是期望的词表大小
# step 1: 找到最常见的紧邻的字节
stats = collections.Counter()
for piece in words:
for pair in zip(piece[:-1], piece[1:]):
stats[pair] += 1
most_common_pair = max(stats, key=lambda x: stats[x])
token_bytes = most_common_pair[0] + most_common_pair[1]
token = len(ranks)
ranks[token_bytes] = token
# step 2: 用新添加的字节 pair 对 words 合并
new_words = []
for word in words:
new_word = []
i = 0
while i < len(word) - 1:
if (word[i], word[i + 1]) == most_common_pair:
# We found our pair! Merge it
new_word.append(token_bytes)
i += 2
else:
new_word.append(word[i])
i += 1
if i == len(word) - 1:
new_word.append(word[i])
new_words.append(new_word)
words = new_words
推理
def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
words = self._pat.findall(text)
tokens = []
for word in words:
word_bytes = word.encode("utf-8")
word_tokens = bpe_encode(self.mergeable_ranks, word_bytes)
tokens.extend(word_tokens)
return tokens
def bpe_encode(mergeable_ranks: dict[bytes, int], input: bytes) -> list[int]:
parts = [bytes([b]) for b in input]
while True:
# 每次进入都只会合并一次
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
# 这里 rank < min_rank 的作用是保证每次合并的是 rank 最小的 pair (rank 越小表示这是 train 过程越早发现的 pair)
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None:
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
tokens = [mergeable_ranks[part] for part in parts]
return tokens
关于 rank < min_rank
的补充解释
mergeable_ranks = {k.encode("utf-8"): v for k, v in {"a": 1, "b": 2, "c": 3, "bc": 89, "ab": 100}.items()}
bpe_encode(mergeable_ranks, 'abc'.encode('utf-8')) # [1, 89]
# 注意这里有两种合并规则: ['a', 'bc'] 和 ['ab', 'c'], 由于 {'bc': 89, 'ab': 100}, 所以 'bc' 更优先
tiktoken 的 Rust 实现
- https://github.com/youkaichao/fast_bpe_tokenizer: 这个仓库包含了一些对 tiktoken 的探索, 但注意无法保证作者的行文是否是统一的 (很可能像本文一样前后不一致)
- 上一节是
tiktoken._education
实现的精确描述, 从本节的叙述可以看出, merge 的规则 Rust 的实现与 Python 的实现是一致的 - tiktoken 中不包含 Rust 训练代码
encode 的核心代码如下, 笔者不熟悉 Rust, 只能尝试注解, 尽量解释成 Python 用户能理解
type Rank = u32;
// 此处 piece 即为使用 pat_str 切分后的一个子块, 而 ranks 即为训练过程中得到的可合并字节
// 整体的流程是先通过 pat_str 将原始文本切块, 然后再对每个块进行 BPE 解码 (特殊情况: 假设 piece 本身就在 rank 中, 则直接将对应的 token id 加入最终的结果里)
// 此函数 (byte_pair_encode) 即为对切块后的文本的 encode 操作
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
assert!(piece.len() > 1);
// _byte_pair_merge 返回的结果会是 [(0, x), (2, x), (3, x), ... (len(piece), x)]
// 代表应将 pieces[0:2], piece[2:3], ... 作为 token
_byte_pair_merge(&ranks, &piece)
.windows(2) // 滑动窗口: [((0, x), (2, x)), ((2, x), (3, x)), ...]
.map(|part| ranks[&piece[part[0].0..part[1].0]]) // 这里相当于 Python 里的一个 lambda 语法: lambda x: ranks[pieces[x[0][0]:x[1][0]]]
.collect()
}
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
// 这里的 ranks 对应于 Python 中的 Dict[bytes, int], 表示前文中的 mergeable_ranks
let mut parts = Vec::with_capacity(piece.len() + 1);
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
// piece[i..i+2] 就表示取相邻的两个 byte, 如果找不到, rank=Rank::MAX, 否则返回对应的 rank
let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
if rank < min_rank.0 {
min_rank = (rank, i);
}
parts.push((i, rank));
}
parts.push((piece.len() - 1, Rank::MAX));
parts.push((piece.len(), Rank::MAX));
// 至此, parts 是一个列表(长度为 piece.len() + 1): [(0, 500), (1, Rank::MAX), (2, 300), ..., (9, Rank::MAX), (10, Rank::MAX)]
// 代表 parts[0:1] 和 parts[2:3] 在 ranks 中出现, 且对应的 token id 为 500 及 300
// 而 min_rank = (300, 2) 代表应该最优先合并的相邻字节
// 这里假设 pieces 的长度为 10, 即 10 个字节
// parts 中每一项 parts[i] 的具体含义是:
// 假设 parts[i] 与 part[i+1] 进行合并, 即 pieces[parts[i][0]:parts[i+2][0]] 做成一个 token,
// 那么它的 rank id 会是 parts[i][1]
// get_rank 是一个闭包(类似于 Python 中的局部函数), |parts: &Vec<(usize, Rank)>, i: usize| 是这个局部函数的输入
let get_rank = {
#[inline(always)]
|parts: &Vec<(usize, Rank)>, i: usize| {
if (i + 3) < parts.len() {
*ranks
.get(&piece[parts[i].0..parts[i + 3].0])
.unwrap_or(&Rank::MAX)
} else {
Rank::MAX
}
}
};
// main loop: 这里具体的解释见下面
while min_rank.0 != Rank::MAX {
let i = min_rank.1;
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1);
}
parts[i].1 = get_rank(&parts, i);
parts.remove(i + 1);
min_rank = (Rank::MAX, usize::MAX);
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, i);
}
}
}
parts // 也就是 Python 中的 return parts
}
我们这里再仔细看一下 main loop 的逻辑 (每次进入 while 时, 假设 i>0
):
(parts[i-1][0], parts[i-1][1]) # 表示如果将 pieces[parts[i-1][0]:parts[i+1][0]] 合在一起, 那么 token id 是 parts[i-1][1]
(parts[i][0], parts[i][1]) #
(parts[i+1][0], parts[i+1][1])
(parts[i+2][0], parts[i+2][1])
(parts[i+3][0], parts[i+3][1])
注意 i = min_rank.1
表示这一步要将 pieces[parts[i][0]:parts[i+2][0]]
合在一起了, 这样一来 pieces[parts[i+1][0]:parts[i+3][0]]
是不可能做合并的, 因此 parts[i+1]
最终需要被移除. 而再次之前, 我们还要更新 parts[i-1]
和 parts[i]
:
parts[i-1]
如果要再进行合并, 那么就只能合并成pieces[parts[i-1][0]:parts[i+2][0]]
, 这也就是get_rank(&parts, i-1)
做的事情parts[i]
如果要再进行合并, 那么就只能合并成pieces[parts[i][0]:parts[i+2][0]]
, 这也就是get_rank(&parts, i)
做的事情
i=0
的情况不难同理推敲, 此处从略.
至此, 我们给出等价的 python 实现 (不难理解, 这个实现本质上与上一节 _education.py
里的实现是等价的, 只是换了种实现方式):
from typing import Dict
MAX_INT = int(1e8)
def _byte_pair_merge(ranks: Dict[bytes, int], piece: bytes):
parts = []
min_rank = (MAX_INT, MAX_INT) # (rank, i)
n = len(piece)
for i in range(n-1):
rank = ranks.get(piece[i:i+2], MAX_INT)
if rank < min_rank[0]:
min_rank = (rank, i)
parts.append([i, rank])
parts.append([n-1, MAX_INT])
parts.append([n, MAX_INT])
def get_rank(parts, i):
if (i+3) < len(parts):
return ranks.get(piece[parts[i][0]:parts[i+3][0]], MAX_INT)
return MAX_INT
while min_rank[0] != MAX_INT:
i = min_rank[1]
if i > 0:
parts[i-1][1] = get_rank(parts, i-1)
parts[i][1] = get_rank(parts, i)
parts.pop(i+1)
min_rank = (MAX_INT, MAX_INT)
for (i, (_, rank)) in enumerate(parts[:len(parts)-1]):
if rank < min_rank[0]:
min_rank = (rank, i)
return parts
def byte_pair_encode(ranks: Dict[bytes, int], piece: bytes):
parts = _byte_pair_merge(ranks, piece)
n = len(parts)
token_ids = []
for i in range(n-1):
token = piece[parts[i][0]:parts[i+1][0]]
token_ids.append(ranks[token])
return token_ids
# 测试用例
ranks = {k.encode("utf-8"):v for k, v in {"a": 1, "b": 2, "c": 3, "ab": 450, "bc": 650}.items()}
piece = "abc".encode()
print(byte_pair_encode(ranks, piece)) # [450, 3]
更多用法
# 这里的 mergeable_ranks 和 special_tokens 是 Encoding.__init__ 的入参
Encoding.n_vocab == (len(mergeable_ranks) + len(special_tokens))