动机、参考资料、涉及内容

  • (Ready) tiktoken 的实现细节, 包括 python 实现与 rust 实现, 完全以学习的角度看: https://github.com/openai/tiktoken
  • (Unknown) 为什么 tiktoken 快过 huggingface 的实现: 具体的测试脚本? huggingface 的 FastTokenizer 也是用 Rust 实现的, 为什么会慢?

tiktoken

使用

tiktoken 是 BPE 算法的实现, 最重要的特点是 encode 与 decode 是无损的

简单使用

from tiktoken import get_encoding
from typing import List
enc = get_encoding("cl100k_base")
text = "hello world"
tokens: List[int] = enc.encode("hello world")
decode_text: str = enc.decode(tokens)
assert text == decode_text

关于 special_token 的处理

enc.encode("<|endofprompt|>", allowed_special=set(), disallowed_special="all")  # 默认值, 触发 Error
enc.encode("<|endofprompt|>", allowed_special="all", disallowed_special="all")  # [100276]

# 源码中的处理逻辑如下:
if allowed_special == "all":
    allowed_special = self.special_tokens_set
if disallowed_special == "all":
    disallowed_special = self.special_tokens_set - allowed_special
# 只要 disallowed_special 不为空, 输入的 text 就不能包含 disallowed_special 中的字符
if disallowed_special:
    if match := _special_token_regex(disallowed_special).search(text):
        raise_disallowed_special_token(match.group())
# 这个 _core_bpe 的类型是 Rust 绑定到 Python 的
return self._core_bpe.encode(text, allowed_special)

前置说明

下面以 cl100k_base 为例来分析 enc = get_encoding("cl100k_base") 的具体流程 (gpt3.5, gpt4, text-embedding-ada-002 这几个模型都是使用这个 tokenizer)

# tiktoken.get_encoding("cl100k_base") 的实际调用流程

ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"

def cl100k_base():
    # cl100k_base.tiktoken 本身是一个文本文件, 每一行是 base64 表示的 token 表示, 以及相应的 token 序号
    mergeable_ranks: dict[bytes, int] = load_tiktoken_bpe(
        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
    )
    # load_tiktoken_bpe 的实际过程就是下载文件(或者从缓存中读取), 然后像这样解析
    # lines = open("cl100k_base.tiktoken", "rb").read().splitlines()
    # mergeable_ranks = {
    #     base64.b64decode(token): int(rank)
    #     for token, rank in (line.split() for line in lines if line)
    # }
    special_tokens = {
        ENDOFTEXT: 100257,
        FIM_PREFIX: 100258,
        FIM_MIDDLE: 100259,
        FIM_SUFFIX: 100260,
        ENDOFPROMPT: 100276,
    }
    return {
        "name": "cl100k_base",
        "pat_str": r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }

tokenizer = Encoding(**cl100k_base())

值得注意点主要包括 mergeable_ranks, pat_str, Encoding 这几处

(1) 这里先探索一下 mergeable_ranks

  • 前 256 个 token 完全覆盖了 \x00 ~ \xff2^8=256 个单字节字符
    import base64
      lines = open("cl100k_base.tiktoken", "rb").read().splitlines()
      mergeable_ranks = {
          base64.b64decode(token): int(rank) for token, rank in (line.split() for line in lines if line)
      }
      a = set([ord(token) for token, num in list(mergeable_ranks.items())[:256]])
      b = set([i for i in range(256)])
      a == b  # True
    
  • mergeable_ranks 长度为 100256, 因此特殊 token 的序号刚好错开
    all([i == num for i, (token, num) in enumerate(list(mergeable_ranks.items()))]) == True # True
    len(mergeable_ranks)  # 100256
    

(2) 这里的 Encoding 主要是对 rust 实现的 tokenizer 的封装, 相应的代码如下

# tiktoken/core.py
class Encoding:
    def __init__(self, ...):
        # 其余代码从略 ...
        self._special_tokens = special_tokens
        # _tiktoken 是一个 rust 编译出的的Python扩展动态链接库, 在 pip install tiktoken 后在硬盘上大约位于
        # site-packages/tiktoken/_tiktoken.cpython-39-x86_64-linux-gnu.so
        self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str)

    @functools.cached_property
    def special_tokens_set(self) -> set[str]:
        return set(self._special_tokens.keys())

    def encode(
        self,
        text: str,
        *,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    ) -> list[int]:
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - allowed_special
        if disallowed_special:  # 如果 disallowed_special = set("")
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        if isinstance(allowed_special, frozenset):
            allowed_special = set(allowed_special)

        try:
            # 大多数情况是进入这个分支
            return self._core_bpe.encode(text, allowed_special)
        except UnicodeEncodeError:  # 如果有兴趣, 这个分支可参考源码说明
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._core_bpe.encode(text, allowed_special)
    def decode(self, tokens: list[int], errors: str = "replace") -> str:
        return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)

由此可以看出工作完全是交接给 Rust 实现的 CoreBPE, 这便是本文的目标

(3) 这个 pat_str 用在预切分, 这里以 gpt-2 举例, 这个正则的特征是能将原始文本完全切分, 也就是切分完后可以重新精确恢复原始文本.

import regex
pat = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
text = "a's 1,123  abc  中国人"
parts = regex.findall(pat, text)
# ['a', "'s", ' 1', ',', '123', ' ', ' abc', ' ', ' 中国人']
# 可以看出对英文比较友好, 而连续的汉字会切分不开
"".join(parts) == text  # True, 这个性质是重点!!!!

总的来说就是一堆“或”关系的匹配, 具体可以拆解为:

  • '(?:[sdmt]|ll|ve|re): 实际上就是代表匹配 's|'d|'m|'t|'ll|'ve|'re, 基本上就是常见的缩写
  • ` ?\p{L}+`: 多个连续的 Unicode 文字
  • ` ?\p{N}+`: 多个连续的 Unicode 数字
  • ` ?[^\s\p{L}\p{N}]+`: 非空白, 非 Unicode 文字,
  • \s+(?!\S): 多个空白, 但不包含最后一个
  • \s+: 多个空白

备注: Unicode 字符集除了 p{L}(文字) 和 p{N}(数字) 外, 还有符号, 空白符, emoji 等等

备注: 最前面的 ?: 只是代表是非捕获组, 解释如下:

import re
text = "today is 2024-05-15, tomorrow is 2024-05-16"
pat1 = r"(\d{4})-(\d{2})-(\d{2})"
pat2 = r"(?:\d{4})-(\d{2})-(\d{2})"
print("捕获组")
for result in re.findall(pat1, text):
    print(result)
print("非捕获组")
for result in re.findall(pat2, text):
    print(result)

输出结果如下(其实只是结果里不包含非捕获组而已)

捕获组
('2024', '05', '15')
('2024', '05', '16')
非捕获组
('05', '15')
('05', '16')

cl100k_base 的这个 pat_str 如下

pat_str = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

值得注意的是, 这个 “正则” 不符合 python 的标注: 例如里面包含这种子序列 ?+, 这对于 python 正则是不合法的, 而像 \p{L} 代表匹配任意的 Unicode 文本字符, 这也同样不能用 python 的 re 模块所解析 (可以借助第三方包 regex). 后续会重新看看这个正则(从 Rust 的实现来看).

bpe 算法

完整实现请直接参考 tiktoken/_education.py

关于 bytes 的注解:

bytes(3)  # 表示 3 个 \x00, 即: b'\x00\x00\x00'
bytes([3, 4])  # 表示 b'\x03\x04'
b''.join([bytes([3]), bytes([4])])  # [b'\x03', b'\x04'] -> b'\x03\x04'

train

1) 初始化:

import regex
pat_str = r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
data = "中国人, haven't"

words: list[list[bytes]] = [
    [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
]

# 具体过程为:
# regex.findall(pat_str, data)  # ['中国人', ',', ' haven', "'t"]
# 然后再将每个部分分解为字节:
# words = [
#     [b'\xe4', b'\xb8', b'\xad', b'\xe5', b'\x9b', b'\xbd', b'\xe4', b'\xba', b'\xba'],
#     [b','],
#     [b' ', b'h', b'a', b'v', b'e', b'n'],
#     [b"'", b't']
# ]
# assert b''.join(words[0]).decode() == "中国人"

2) 迭代:

while len(ranks) < vocab_size:  # vocab_size 是期望的词表大小
    # step 1: 找到最常见的紧邻的字节
    stats = collections.Counter()
    for piece in words:
        for pair in zip(piece[:-1], piece[1:]):
            stats[pair] += 1

    most_common_pair = max(stats, key=lambda x: stats[x])
    token_bytes = most_common_pair[0] + most_common_pair[1]
    token = len(ranks)
    ranks[token_bytes] = token

    # step 2: 用新添加的字节 pair 对 words 合并
    new_words = []
    for word in words:
        new_word = []
        i = 0
        while i < len(word) - 1:
            if (word[i], word[i + 1]) == most_common_pair:
                # We found our pair! Merge it
                new_word.append(token_bytes)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        if i == len(word) - 1:
            new_word.append(word[i])
        new_words.append(new_word)
    words = new_words

推理

def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
    words = self._pat.findall(text)
    tokens = []
    for word in words:
        word_bytes = word.encode("utf-8")
        word_tokens = bpe_encode(self.mergeable_ranks, word_bytes)
        tokens.extend(word_tokens)
    return tokens

def bpe_encode(mergeable_ranks: dict[bytes, int], input: bytes) -> list[int]:
    parts = [bytes([b]) for b in input]
    while True:
        # 每次进入都只会合并一次
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            # 这里 rank < min_rank 的作用是保证每次合并的是 rank 最小的 pair (rank 越小表示这是 train 过程越早发现的 pair)
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank

        if min_rank is None:
            break
        assert min_idx is not None

        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]

    tokens = [mergeable_ranks[part] for part in parts]
    return tokens

关于 rank < min_rank 的补充解释

mergeable_ranks = {k.encode("utf-8"): v for k, v in {"a": 1, "b": 2, "c": 3, "bc": 89, "ab": 100}.items()}
bpe_encode(mergeable_ranks, 'abc'.encode('utf-8'))  # [1, 89]
# 注意这里有两种合并规则: ['a', 'bc'] 和 ['ab', 'c'], 由于 {'bc': 89, 'ab': 100}, 所以 'bc' 更优先 

tiktoken 的 Rust 实现

  • https://github.com/youkaichao/fast_bpe_tokenizer: 这个仓库包含了一些对 tiktoken 的探索, 但注意无法保证作者的行文是否是统一的 (很可能像本文一样前后不一致)
  • 上一节是 tiktoken._education 实现的精确描述, 从本节的叙述可以看出, merge 的规则 Rust 的实现与 Python 的实现是一致的
  • tiktoken 中不包含 Rust 训练代码

encode 的核心代码如下, 笔者不熟悉 Rust, 只能尝试注解, 尽量解释成 Python 用户能理解

type Rank = u32;

// 此处 piece 即为使用 pat_str 切分后的一个子块, 而 ranks 即为训练过程中得到的可合并字节
// 整体的流程是先通过 pat_str 将原始文本切块, 然后再对每个块进行 BPE 解码 (特殊情况: 假设 piece 本身就在 rank 中, 则直接将对应的 token id 加入最终的结果里)
// 此函数 (byte_pair_encode) 即为对切块后的文本的 encode 操作
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    assert!(piece.len() > 1);
    // _byte_pair_merge 返回的结果会是 [(0, x), (2, x), (3, x), ... (len(piece), x)]
    // 代表应将 pieces[0:2], piece[2:3], ... 作为 token
    _byte_pair_merge(&ranks, &piece)
        .windows(2)  // 滑动窗口: [((0, x), (2, x)), ((2, x), (3, x)), ...]
        .map(|part| ranks[&piece[part[0].0..part[1].0]])  // 这里相当于 Python 里的一个 lambda 语法: lambda x: ranks[pieces[x[0][0]:x[1][0]]]
        .collect()
}

fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // 这里的 ranks 对应于 Python 中的 Dict[bytes, int], 表示前文中的 mergeable_ranks
    let mut parts = Vec::with_capacity(piece.len() + 1);
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        // piece[i..i+2] 就表示取相邻的两个 byte, 如果找不到, rank=Rank::MAX, 否则返回对应的 rank
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));
    // 至此, parts 是一个列表(长度为 piece.len() + 1): [(0, 500), (1, Rank::MAX), (2, 300), ..., (9, Rank::MAX), (10, Rank::MAX)]
    // 代表 parts[0:1] 和 parts[2:3] 在 ranks 中出现, 且对应的 token id 为 500 及 300
    // 而 min_rank = (300, 2) 代表应该最优先合并的相邻字节
    // 这里假设 pieces 的长度为 10, 即 10 个字节
    // parts 中每一项 parts[i] 的具体含义是:
    //     假设 parts[i] 与 part[i+1] 进行合并, 即 pieces[parts[i][0]:parts[i+2][0]] 做成一个 token,
    //     那么它的 rank id 会是 parts[i][1]

    // get_rank 是一个闭包(类似于 Python 中的局部函数), |parts: &Vec<(usize, Rank)>, i: usize| 是这个局部函数的输入
    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    // main loop: 这里具体的解释见下面
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts  // 也就是 Python 中的 return parts
}

我们这里再仔细看一下 main loop 的逻辑 (每次进入 while 时, 假设 i>0):

(parts[i-1][0], parts[i-1][1])  # 表示如果将 pieces[parts[i-1][0]:parts[i+1][0]] 合在一起, 那么 token id 是 parts[i-1][1]
(parts[i][0], parts[i][1])      # 
(parts[i+1][0], parts[i+1][1])
(parts[i+2][0], parts[i+2][1])
(parts[i+3][0], parts[i+3][1])

注意 i = min_rank.1 表示这一步要将 pieces[parts[i][0]:parts[i+2][0]] 合在一起了, 这样一来 pieces[parts[i+1][0]:parts[i+3][0]] 是不可能做合并的, 因此 parts[i+1] 最终需要被移除. 而再次之前, 我们还要更新 parts[i-1]parts[i]:

  • parts[i-1] 如果要再进行合并, 那么就只能合并成 pieces[parts[i-1][0]:parts[i+2][0]], 这也就是 get_rank(&parts, i-1) 做的事情
  • parts[i] 如果要再进行合并, 那么就只能合并成 pieces[parts[i][0]:parts[i+2][0]], 这也就是 get_rank(&parts, i) 做的事情

i=0 的情况不难同理推敲, 此处从略.

至此, 我们给出等价的 python 实现 (不难理解, 这个实现本质上与上一节 _education.py 里的实现是等价的, 只是换了种实现方式):

from typing import Dict
MAX_INT = int(1e8)
def _byte_pair_merge(ranks: Dict[bytes, int], piece: bytes):
    parts = []
    min_rank = (MAX_INT, MAX_INT)  # (rank, i)
    n = len(piece)
    for i in range(n-1):
        rank = ranks.get(piece[i:i+2], MAX_INT)
        if rank < min_rank[0]:
            min_rank = (rank, i)
        parts.append([i, rank])
    parts.append([n-1, MAX_INT])
    parts.append([n, MAX_INT])

    def get_rank(parts, i):
        if (i+3) < len(parts):
            return ranks.get(piece[parts[i][0]:parts[i+3][0]], MAX_INT)
        return MAX_INT

    while min_rank[0] != MAX_INT:
        i = min_rank[1]
        if i > 0:
            parts[i-1][1] = get_rank(parts, i-1)
        parts[i][1] = get_rank(parts, i)
        parts.pop(i+1)
        min_rank = (MAX_INT, MAX_INT)
        for (i, (_, rank)) in enumerate(parts[:len(parts)-1]):
            if rank < min_rank[0]:
                min_rank = (rank, i)

    return parts

def byte_pair_encode(ranks: Dict[bytes, int], piece: bytes):
    parts = _byte_pair_merge(ranks, piece)
    n = len(parts)
    token_ids = []
    for i in range(n-1):
        token = piece[parts[i][0]:parts[i+1][0]]
        token_ids.append(ranks[token])
    return token_ids

# 测试用例
ranks = {k.encode("utf-8"):v for k, v in {"a": 1, "b": 2, "c": 3, "ab": 450, "bc": 650}.items()}
piece = "abc".encode()
print(byte_pair_encode(ranks, piece))  # [450, 3]

更多用法

# 这里的 mergeable_ranks 和 special_tokens 是 Encoding.__init__ 的入参
Encoding.n_vocab == (len(mergeable_ranks) + len(special_tokens))