动机、参考资料、涉及内容

涉及内容: GGUF 文件格式, llama.cpp, GGML

简介

GGML(GPT-Generated Model Language) 是一个开源项目(可以类比 pytorch), 用于高效的利用 CPU 进行张量计算, 项目定义了一种 GGUF(GPT-Generated Unified Format) 文件格式. GGML 的作者也同样是 llama.cpp 的作者, 而 llama.cpp 用于高效的利用 CPU 进行市面上流行的许多开源大模型的推理.

目前看 GGML 和 llama.cpp 项目没有依赖关系, 因此可能有代码会重复?

如果只希望快速部署一些开源大模型, 一般只需要关心 llama.cpp 即可, 具体例子见下节

llama.cpp Quick Start

普通的算法工程师一般只需要知道按如下方式部署一个现成的模型

参考: https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF, 步骤如下:

git clone --depth=1 -b master https://github.com/ggerganov/llama.cpp.git
cd llama.cpp
# 编译可执行文件
make -j8
# 下载 gguf 文件
huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct-GGUF qwen2.5-1.5b-instruct-q5_k_m.gguf --local-dir . --local-dir-use-symlinks False
# 运行一个命令行交互的简单 demo
./llama-cli -m qwen2.5-1.5b-instruct-q5_k_m.gguf -co -cnv -p "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." -fa -ngl 80 -n 512

更深入的情况是:

  • Q1: 如果只有 pytorch_model.bin 这种形式的权重, 怎么转换为 GGUF 格式?
  • A1: 如果模型结构被 convert_hf_to_gguf.py 脚本支持, 那么直接用脚本转化即可
  • Q2: 如果模型结构不被支持呢?
  • A2: 可以修改 convert_hf_to_gguf.py 文件, 其实如果你这么做了, 你事实上可以为 llama.cpp 提 PR, 实际做起来你可能需要参考 HOWTO-add-model.md
  • Q3: 如果你很想知道 GGUF 文件的具体格式, 以探索用 GGUF 文件存储别的东西, 或者做更多的事情
  • A3: 请看下文

通常情况下, 至多只需要知道 A1 即可: 一般是你对模型做了微调, 需要自己转换格式, 但一般来说也就是用 huggingface transformers 进行微调, 而它大概也继承了这个转换脚本, 因此甚至于不需要知道 A1.

GGUF 文件格式

llama.cpp 项目下包含一个子项目: gguf-py (pypi:gguf), 这个子项目是纯 python 实现 GGUF 文件的读写, 因此可以用来了解 GGUF 的文件格式.

TODO: llama.cpp 里实际使用的应该是 C 语言的实现, 待搜寻

初体验: 读取一个实际的 GGUF 文件

cd /path/to/llama.cpp/gguf-py
pip install -e .
python examples/reader.py ../qwen/qwen2.5-1.5b-instruct-q5_k_m.gguf

可以得到这样的输出(省略了许多层)

Key-Value Pairs:
GGUF.version                           : [3]
GGUF.tensor_count                      : [339]
GGUF.kv_count                          : [26]
general.architecture                   : [113 119 101 110  50]
general.type                           : [109 111 100 101 108]
general.name                           : [113 119 101 110  50  46  53  45  49  46  53  98  45 105 110 115 116 114 117  99 116]
general.version                        : [118  48  46  49]
general.finetune                       : [113 119 101 110  50  46  53  45  49  46  53  98  45 105 110 115 116 114 117  99 116]
general.size_label                     : [49 46 56 66]
qwen2.block_count                      : [28]
qwen2.context_length                   : [32768]
qwen2.embedding_length                 : [1536]
qwen2.feed_forward_length              : [8960]
qwen2.attention.head_count             : [12]
qwen2.attention.head_count_kv          : [2]
qwen2.rope.freq_base                   : [1000000.]
qwen2.attention.layer_norm_rms_epsilon : [1.e-06]
general.file_type                      : [17]
tokenizer.ggml.model                   : [103 112 116  50]
tokenizer.ggml.pre                     : [113 119 101 110  50]
tokenizer.ggml.tokens                  : [33]
tokenizer.ggml.token_type              : [1]
tokenizer.ggml.merges                  : [196 160  32 196 160]
tokenizer.ggml.eos_token_id            : [151645]
tokenizer.ggml.padding_token_id        : [151643]
tokenizer.ggml.bos_token_id            : [151643]
tokenizer.ggml.add_bos_token           : [False]
tokenizer.chat_template                : [123  37  45 ...  37 125  10]
general.quantization_version           : [2]
----
Tensors:
Tensor Name                    | Shape: Shape           | Size: Size         | Quantization: Quantization
--------------------------------------------------------------------------------
output.weight                  | Shape: 1536x151936     | Size: 233373696    | Quantization: Q6_K
token_embd.weight              | Shape: 1536x151936     | Size: 233373696    | Quantization: Q5_K
blk.0.attn_norm.weight         | Shape: 1536            | Size: 1536         | Quantization: F32
blk.0.ffn_down.weight          | Shape: 8960x1536       | Size: 13762560     | Quantization: Q6_K
blk.0.ffn_gate.weight          | Shape: 1536x8960       | Size: 13762560     | Quantization: Q5_K
blk.0.ffn_up.weight            | Shape: 1536x8960       | Size: 13762560     | Quantization: Q5_K
blk.0.ffn_norm.weight          | Shape: 1536            | Size: 1536         | Quantization: F32
blk.0.attn_k.bias              | Shape: 256             | Size: 256          | Quantization: F32
blk.0.attn_k.weight            | Shape: 1536x256        | Size: 393216       | Quantization: Q5_K
blk.0.attn_output.weight       | Shape: 1536x1536       | Size: 2359296      | Quantization: Q5_K
blk.0.attn_q.bias              | Shape: 1536            | Size: 1536         | Quantization: F32
blk.0.attn_q.weight            | Shape: 1536x1536       | Size: 2359296      | Quantization: Q5_K
blk.0.attn_v.bias              | Shape: 256             | Size: 256          | Quantization: F32
blk.0.attn_v.weight            | Shape: 1536x256        | Size: 393216       | Quantization: Q6_K
...
output_norm.weight             | Shape: 1536            | Size: 1536         | Quantization: F32

这些输出可以与huggingface-hub的显示进行对照, 两者本质是完全一致的, 只是 huggingface-hub 上的显示更为友好一些, 例如:

# general.architecture: [113 119 101 110  50]
x = [113, 119, 101, 110, 50]
assert "".join([chr(_) for _ in x]) == "qwen2"

GGUF 与 Pytorch 相互转化, 可以参考: https://huggingface.co/docs/transformers/v4.45.2/en/gguf

# 利用 transformers 将 GGUF 转化为普通的 Pytorch Module 的 state_dict: 本质上也是利用 gguf python 包: from gguf import GGUFReader, dequantize
from transformers import AutoModelForCausalLM
import os
path = "/path/to/qwen/qwen2.5-1.5b-instruct-q5_k_m.gguf"
model = AutoModelForCausalLM.from_pretrained(os.path.dirname(path), gguf_file=path)  # 转化后是普通版本 float 权重的模型

# 将 transformers 的 Pytorch Module 权重文件转化为 GGUF 格式: 本质上是一个模型一个模型去与 huggingface 对齐的, 实质上也是使用了 GGUFWriter
https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py

接下来, 我们将深入理解读取过程, 从而理解 GGUF 的数据排布方式

前置知识: numpy 与 C 结构体的文件读写

隐藏

numpy: memmap 与 .npy 文件

import numpy as np
import struct

# memmap
nrows, ncols = 2, 3
arr = np.memmap('arr.dat', dtype=np.int32, mode='w', shape=(nrows, ncols))
arr[0][0] = 1
arr[0][1] = 2
arr[0][2] = 3
arr[1][0] = 4
arr[1][1] = 5
arr[1][2] = 6
arr.flush()

with open('arr.dat', "rb") as fr:
    x = fr.read()

struct.unpack("6I", x)  # (1, 2, 3, 4, 5, 6)


# npy 文件
with open("arr.npy", "wb") as fw:
    np.save(fw, np.array([1, 2, 3, 4], dtype=np.int32))
with open("arr.npy", "rb") as fr:
    x = fr.read()  # b"\x93NUMPY\x01\x00v\x00{'descr': '<i4', 'fortran_order': False, 'shape': (4,), }
len(x)  # 144

以上示例代码有如下要点:

  • np.memmap 写入的内容只包含纯数据, 而 .npy 文件还包含了一些额外信息
  • numpy 数组是行优先存储的

C 语言结构体写入文件

#include <stdio.h>
#include <string.h>

struct MyStruct {
    int id;
    float value;
    char name[20];
};

int write_example() {
    struct MyStruct example;

    example.id = 1;
    example.value = 10.5;
    snprintf(example.name, sizeof(example.name), "ExampleName");

    // 打开文件以二进制写入模式
    FILE *file = fopen("data.bin", "wb");
    if (!file) {
        perror("Failed to open file");
        return 1;
    }

    // 将结构体写入文件
    fwrite(&example, sizeof(struct MyStruct), 1, file);

    // 关闭文件
    fclose(file);

    printf("Structure saved to binary file.\n");
    return 0;
}

int read_example() {
    struct MyStruct example_read;
    FILE *file = fopen("data.bin", "rb");
    if (!file) {
        perror("Failed to open file");
        return 1;
    }

    fread(&example_read, sizeof(struct MyStruct), 1, file);
    fclose(file);

    printf("ID: %d, Value: %f, Name: %s\n", example_read.id, example_read.value, example_read.name);
    return 0;
}

int main() {
    write_example();
    read_example();
    return 0;
}

使用 python 读取文件

import numpy as np
with open("data.bin", "rb") as fr:
    x = fr.read(4) # 32 位整数
    y = fr.read(4) # 浮点数
    z = fr.read()  # 长度为 20

np.frombuffer(x, dtype=np.int32)   # array([1], dtype=int32)
np.frombuffer(y, dtype=np.float32) # array([10.5], dtype=float32)
"".join([chr(x) for x in np.frombuffer(z, dtype=np.uint8)])  # 'ExampleName\x00\x00\x00\x00\x00 á\x1d*'

GGUF 文件数据排布

reader.py 详解: GGUF 文件的内容格式如下图所示

https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/gguf-spec.png

以下为文件内容的 Overview (精确描述)

  • 文件的前 4 个字节以 uint32 按 Little Endian 读取必须是 GGUF
  • 接下来的 4 个字节以 uint32 按系统默认的字节序读取, 得到版本号(目前是3), 并同时得知后续内容中文件的字节序是 Big Endian 还是 Little Endian
  • 接下来的 8 个字节代表 metadata 总共有几个
  • 接下来的 8 个字节代表 tensor 总共有几个
  • metadata 的实际数据内容
  • tensor 的元信息, 例如数据类型, 数据个数, 起始位置
  • padding: 字节对齐, 使得第一个 tensor 的起始位置是字节对齐的
  • tensor 的实际数据内容, 每个 tensor 结尾都需要 padding 用于字节对齐

其准确的数据排布定义于 https://github.com/ggerganov/ggml/blob/master/docs/gguf.md 文件内

个人注释版
// GGUF 所支持的 tensor 数据类型
enum ggml_type: uint32_t {
    GGML_TYPE_F32     = 0,  // float32 类型
    GGML_TYPE_F16     = 1,  // float16 类型
    GGML_TYPE_Q4_0    = 2,
    // ...
    GGML_TYPE_IQ1_M   = 29,
    GGML_TYPE_COUNT,
};

// GGUF 文件中 value 的数据类型
enum gguf_metadata_value_type: uint32_t {
    GGUF_METADATA_VALUE_TYPE_UINT8 = 0,
    GGUF_METADATA_VALUE_TYPE_INT8 = 1,
    GGUF_METADATA_VALUE_TYPE_UINT16 = 2,
    GGUF_METADATA_VALUE_TYPE_INT16 = 3,
    GGUF_METADATA_VALUE_TYPE_UINT32 = 4,
    GGUF_METADATA_VALUE_TYPE_INT32 = 5,
    GGUF_METADATA_VALUE_TYPE_FLOAT32 = 6,
    GGUF_METADATA_VALUE_TYPE_BOOL = 7,
    GGUF_METADATA_VALUE_TYPE_STRING = 8,
    GGUF_METADATA_VALUE_TYPE_ARRAY = 9,
    GGUF_METADATA_VALUE_TYPE_UINT64 = 10,
    GGUF_METADATA_VALUE_TYPE_INT64 = 11,
    GGUF_METADATA_VALUE_TYPE_FLOAT64 = 12,
};

// GGUF 文件中的字符串数据结构: len 代表长度, 然后跟着实际的字符串数据
struct gguf_string_t {
    uint64_t len;
    char string[len];
};

// GGUF 文件中 metadata 的 value 实际值的数据结构
union gguf_metadata_value_t {
    // 数值类型: 直接存值
    uint8_t uint8;
    int8_t int8;
    uint16_t uint16;
    int16_t int16;
    uint32_t uint32;
    int32_t int32;
    float float32;
    uint64_t uint64;
    int64_t int64;
    double float64;
    bool bool_;
    // 字符串类型: 先存长度, 再存值
    gguf_string_t string;
    // 数组类型: 先存元素类型, 再存元素个数, 最后存值
    struct {
        gguf_metadata_value_type type;
        uint64_t len;
        gguf_metadata_value_t array[len];
    } array;
};

// GGUF 文件中一项 metadata 的 kv 对的数据结构: 先存key,再存value的数据类型,最存value的值
struct gguf_metadata_kv_t {
    gguf_string_t key;
    gguf_metadata_value_type value_type;
    gguf_metadata_value_t value;
};

// GGUF 文件头部分
struct gguf_header_t {
    uint32_t magic;  // "GGUF" 的字节表示: 0x47475546
    uint32_t version;  // 目前是 3
    uint64_t tensor_count;  // 文件中 tensor 的数量
    uint64_t metadata_kv_count;  // 文件中 metadata 的数量
    gguf_metadata_kv_t metadata_kv[metadata_kv_count];  // metadata
};

// 字节对齐: 每个 tensor data 的起始位置必须是 ALIGNMENT 的整数倍
uint64_t align_offset(uint64_t offset) {
    return offset + (ALIGNMENT - (offset % ALIGNMENT)) % ALIGNMENT;
}

// GGUF 文件的一项 tensor info 的数据结构
struct gguf_tensor_info_t {
    gguf_string_t name;  // tensor 的名称
    uint32_t n_dimensions;  // tensor 的维度, 目前最大是 4
    uint64_t dimensions[n_dimensions];  // tensor 的 shape.
    ggml_type type;  // tensor 的数据类型
    uint64_t offset;  // 相较于 gguf_file_t.tensor_data 的 offset
};


// GGUF 文件
struct gguf_file_t {
    gguf_header_t header;  // GGUF 文件头部分
    gguf_tensor_info_t tensor_infos[header.tensor_count];  // tensor info
    // 字节对齐
    uint8_t _padding[];
    // tensor 的实际值, 每个 tensor 的起始位置必须是 ALIGNMENT 的整数倍(字节对齐)
    uint8_t tensor_data[];
};

GGUFReader 详解

qwen2.5-1.5b-instruct-q5_k_m.gguf 文件为例, 阅读 gguf.GGUFReader__init__ 方法

本质上是按顺序解析

  • (1) 头部
  • (2) metadata
  • (3) tensor info
  • (4) tensor data

(1) 头部

4+4+8+8=24 个字节, GGUFReader 最终会在 self.fields 里增加 "GGUF.version", "GGUF.tensor_count", "GGUF.kv_count" 这三个 ReaderField

以下是直接对字节进行手工解析的代码

import numpy as np
with open(path, "rb") as fr:
    offset = 0
    
    x = fr.read(4)  # b'GGUF': 魔术数字
    offset += 4
    
    x = fr.read(4)   # b'\x03\x00\x00\x00'
    np.frombuffer(x, dtype=np.uint32)  # array([3], dtype=uint32): GGUF 版本号
    offset += 4

    tensor_count = fr.read(8)
    tensor_count = int(np.frombuffer(tensor_count, dtype=np.uint64))  # 339, 表示有 339 个 tensor
    offset += 8

    kv_count = fr.read(8)
    kv_count = int(np.frombuffer(kv_count, dtype=np.uint64)[0])  # 26, 表示后续有 26 个 metadata
    offset += 8

(2) metadata

class GGUFValueType(IntEnum):
    UINT8   = 0
    INT8    = 1
    UINT16  = 2
    INT16   = 3
    UINT32  = 4
    INT32   = 5
    FLOAT32 = 6
    BOOL    = 7
    STRING  = 8
    ARRAY   = 9
    UINT64  = 10
    INT64   = 11
    FLOAT64 = 12

# 一个 ReaderField 代表了一个 metadata 的 kv 对
class ReaderField(NamedTuple):
    offset: int  # 起始地址
    name: str    # key的字符串表示
    parts: list[npt.NDArray[Any]] = []  # 将字节拆分为多个部分, 将 parts 转化为字节合并在一起即为文件中针对该 metadata 的原始的字节表示
    data: list[int] = [-1]  # parts 的下标数组, 真实的数据所对应的实际下标
    types: list[GGUFValueType] = []  # value的数据类型描述, 对于简单的标量, types 仅包含一个元素, 对于嵌套情形, types 包含多个元素

# GGUFReader 中最终是将这些 metadata 放在 self.fields 里
# field: ReaderField
# self.fields[field.name] = field

以下是直接对字节进行手工解析的代码(TODO)

import numpy as np
from gguf.constants import GGUFValueType

with open(path, "rb") as fr:
    fr.seek(24)  # 跳过头部的 24 个字节

    # value 的类型仅支持数字,字符串, Array. 但 Array 里必须是同类型的, 且允许嵌套(嵌套情形后面再看)

    # metadata 的数据组织形式
    # uint64: key 的长度
    # key
    # uint32: value 的类型
    # 以下排布方式与 value 类型相关:
    # (1) string:
    # uint64: value 的长度
    # value
    # 例子: [20, general.architecture, 8, 5, qwen2], GGUFValueType.STRING=8
    # (2) uint8,int8,uint16,int16,uint32,int32,uint64,int64,bool,float32,float64
    # value
    # 例子: [26, qwen2.attention.head_count, 4, 12], GGUFValueType.UINT32=4
    # (3) array
    # uint32: value 中每个元素的类型
    # value: 根据每个元素是字符串还是数值类型确定
    # 例子(array[string]): [21, tokenizer.ggml.tokens, 9, 8, 1, 33, 1, 34, ...]
    # 其中 GGUFValueType.ARRAY=9, GGUFValueType.STRING=8
    # 接下来每两个数据项为一组, 例如: [1, 33] 代表 `!`, [1, 34] 代表 `"`
    

    # ReaderField: NamedTuple
    # offest: int, 起始地址
    # name: str, key 的字符串形式
    # parts: List[np.array], 以 value 是字符串类型为例, parts 是一个五元组
    #     [
    #         np.array([20], dtype=np.uint64),
    #         np.array([103, 101, 110, 101, 114,  97, 108,  46,  97, 114,  99, 104, 105, 116, 101,  99, 116, 117, 114, 101], dtype=np.uint8),
    #         np.array([8], dtype=np.uint32),
    #         np.array([5], dtype=np.uint64),
    #         np.array([113, 119, 101, 110,  50], dtype=np.uint8)
    #     ]
    # data: List[int], parts 中哪些项是数据, value 是字符串类型为例, data 是 [4], 代表 parts[4] 是 value 的值 
    # types: List[GGUFValueType], 以 value 是字符串类型为例, types = [GGUFValueType.STRING]
    # types 与 data 的长度不一定一致, 例如 data=[6,8,...], type = [GGUFValueType.ARRAY, GGUFValueType.STRING]
    
    key_len = fr.read(8)
    key_len = int(np.frombuffer(key_len, dtype=np.uint64)[0])
    offset += 8
    
    key_data = fr.read(key_len)
    key_data = np.frombuffer(key_data, dtype=np.uint8)
    offset += key_len

    raw_kv_type = fr.read(4)
    raw_kv_type = int(np.frombuffer(raw_kv_type, dtype=np.uint32)[0])  # 8, 代表 String

metadata 中的 value 可以是 Array, 且允许一定程度的嵌套, 首先利用 GGUFWriterGGUFReader 观察一下可以怎么嵌套

from gguf import GGUFWriter, GGUFReader
import numpy as np

def writer_example() -> None:
    gguf_writer = GGUFWriter("example.gguf", "llama")
    
    gguf_writer.add_array("arr", [[1, 2, 3], [4, 5, 6]])  # 情况1, ok
    # gguf_writer.add_array("arr", [[1, 2, 3], ["abc", "def"]])  # 情况2, ok, 嵌套时允许不等长, 也允许基本元素类型不一致
    # gguf_writer.add_array("arr", [[1, 2, 3], ["abc", "def", 3]])  # error
    
    tensor1 = np.ones((32,), dtype=np.float32) * 100.0
    gguf_writer.add_tensor("tensor1", tensor1)
    gguf_writer.write_header_to_file()
    gguf_writer.write_kv_data_to_file()
    gguf_writer.write_tensors_to_file()
    gguf_writer.close()

writer_example()
reader = GGUFReader("example.gguf")

parts = reader.fields['arr'].parts
types = reader.fields["arr"].types
data_indexes = reader.fields["arr"].data
data = [parts[idx] for idx in data_indexes]

print("parts:", parts)
print("types:", types)
print("data_indexes:", data_indexes)
print("data:", data)
情况1

情况1: [[1, 2, 3], [4, 5, 6]]

parts:
[
    memmap([3], dtype=uint64),           # key 字符串 "arr" 的字节数是 3
    memmap([97, 114, 114], dtype=uint8), # key 字符串: "arr" 的 ASCII 码表示是 [97, 114, 114]
    memmap([9], dtype=uint32),           # GGUFValueType.ARRAY=9, 代表 value 的数据类型是 ARRAY, 因此按照 gguf_metadata_value_t 的定义, 接下来需要记录 value 中每个元素的数据类型和 value 的长度, 最后是 value 的实际数据
    memmap([9], dtype=uint32),           # GGUFValueType.ARRAY=9, 由于 value 的数据类型是 ARRAY, 这里需要记录 value 中元素的数据类型, 这里内层依然是 ARRAY
    memmap([2], dtype=uint64),           # 代表 value 的长度是 2
    memmap([5], dtype=uint32),           # GGUFValueType.INT32=5, 由于 value 中元素的数据类型是 ARRAY, 因此需要记录内部数据类型, 长度以及实际内容
    memmap([3], dtype=uint64),           # 长度为 3
    memmap([1], dtype=int32),            # 1, 实际数据
    memmap([2], dtype=int32),            # 2, 实际数据
    memmap([3], dtype=int32),            # 3, 实际数据
    memmap([5], dtype=uint32),           # GGUFValueType.INT32=5, 由于 value 中元素的数据类型是 ARRAY, 因此需要记录内部数据类型, 长度以及实际内容
    memmap([3], dtype=uint64),           # 长度为 3
    memmap([4], dtype=int32),            # 4, 实际数据
    memmap([5], dtype=int32),            # 5, 实际数据
    memmap([6], dtype=int32)             # 6, 实际数据
]
types:
[<GGUFValueType.ARRAY: 9>, <GGUFValueType.ARRAY: 9>, <GGUFValueType.INT32: 5>]
data_indexes:
[7, 8, 9, 12, 13, 14]
data:
[memmap([1], dtype=int32), memmap([2], dtype=int32), memmap([3], dtype=int32), memmap([4], dtype=int32), memmap([5], dtype=int32), memmap([6], dtype=int32)]

备注: types 字段看起来少记录了一个 GGUFValueType.INT32, 这应该是 gguf 的 BUG

情况2

情况2: [[1, 2, 3], ["abc", "def"]]

parts:
[
    memmap([3], dtype=uint64),            # key 字符串 "arr" 的字节数是 3
    memmap([ 97, 114, 114], dtype=uint8), # key 字符串: "arr" 的 ASCII 码表示是 [97, 114, 114]
    memmap([9], dtype=uint32),            # GGUFValueType.ARRAY=9, 代表 value 的数据类型是 ARRAY, 因此按照 gguf_metadata_value_t 的定义, 接下来需要记录 value 中每个元素的数据类型和 value 的长度, 最后是 value 的实际数据
    memmap([9], dtype=uint32),            # GGUFValueType.ARRAY=9, 由于 value 的数据类型是 ARRAY, 这里需要记录 value 中元素的数据类型, 这里内层依然是 ARRAY
    memmap([2], dtype=uint64),            # 2, 代表 value 的长度是 2
    memmap([5], dtype=uint32),            # GGUFValueType.INT32=5, 由于 value 中元素的数据类型是 ARRAY, 因此需要记录内部数据类型, 长度以及实际内容
    memmap([3], dtype=uint64),            # 长度为 3
    memmap([1], dtype=int32),             # 1, 实际数据
    memmap([2], dtype=int32),             # 2, 实际数据
    memmap([3], dtype=int32),             # 3, 实际数据
    memmap([8], dtype=uint32),            # GGUFValueType.STRING=8, 由于 value 中元素的数据类型是 ARRAY, 因此需要记录内部数据类型, 长度以及实际内容
    memmap([2], dtype=uint64),            # 长度为 2
    memmap([3], dtype=uint64),            # 字符串长度是 3
    memmap([97, 98, 99], dtype=uint8),    # 实际数据: "abc"
    memmap([3], dtype=uint64),            # 字符串长度是 3
    memmap([100, 101, 102], dtype=uint8)  # 实际数据: "def"
]
types:
[<GGUFValueType.ARRAY: 9>, <GGUFValueType.ARRAY: 9>, <GGUFValueType.INT32: 5>]
data_indexes:
[7, 8, 9, 13, 15]
data:
[memmap([1], dtype=int32), memmap([2], dtype=int32), memmap([3], dtype=int32), memmap([97, 98, 99], dtype=uint8), memmap([100, 101, 102], dtype=uint8)]

备注: types 字段看起来少记录了一个 GGUFValueType.STRING, 这应该是 gguf 的 BUG

(3) tensor info 以及 padding

TODO

(4) tensor data

TODO

量化与反量化

前面已经提及 huggingface transformers 代码库里的 from_pretrained 方法可以直接读取 GGUF 文件, 这个过程将 GGUF 里的权重反量化为浮点数. 而这个反量化本质上是 gguf 包提供的, 下面先看一个例子验证这一点:

  • 使用 from_pretrained 方法得到的模型权重
  • 直接使用 gguf 包提供的 GGUFReaderdequantize 手动反量化权重

环境要求: gguf == 0.10.0, transformers==4.45.2

from transformers import AutoModelForCausalLM
from gguf import GGUFReader, GGUFWriter, dequantize, quantize
import os

path = "/content/qwen2.5-1.5b-instruct-q5_k_m.gguf"

model = AutoModelForCausalLM.from_pretrained(os.path.dirname(path), gguf_file=path)
model_tensor = model.state_dict()["lm_head.weight"]
print(model_tensor.shape)  # torch.Size([151936, 1536])
print(model_tensor[:2, :3].numpy())
# array([[ 0.00715065,  0.01251364, -0.01072598],
#        [ 0.00555754,  0.0155611 ,  0.02334166]], dtype=float32)


reader = GGUFReader(path)
gguf_tensor = reader.tensors[0]
print(gguf_tensor.name)  # "output.weight"
print(gguf_tensor.tensor_type)   # <GGMLQuantizationType.Q6_K: 14>
print(gguf_tensor.data.shape)  # (151936, 1260), 量化后的数据维度
print(gguf_tensor.shape)  # (151916, 1536) 原始 float 权重的维度

gguf_float_tensor = dequantize(gguf_tensor.data, gguf_tensor.tensor_type)
print(gguf_float_tensor[:2, :3])
# array([[ 0.00715065,  0.01251364, -0.01072598],
#        [ 0.00555754,  0.0155611 ,  0.02334166]], dtype=float32)

可以看到, 上面两种做法得到的浮点数形式的权重是一致的. 因此我们下一步可以深入研究 dequantize 方法, 相关代码是: llama.cpp/gguf-py/gguf/quants.py

我们先看这个例子中用到的 GGMLQuantizationType.Q6_K

gguf.quants.__Quant

gguf.dequantize 方法实际上也就是 gguf.quants.Q6_K.dequantize, 而 Q6_K 继承自 gguf.quants.__Quant.

gguf 采用的量化方法都有如下统一的设定: 假设一个浮点数权重原始是 (out_size, in_size), 而量化后的权重表示是 (out_size, qin_size), 对于 LLM 的输出层来说 out_size 也就是词表长度, in_size 就是最后一个隐藏层的维度.

量化方式总是逐行分组量化的, 也就是将 in_size 进一步按每组 block_size 划分, 也就是:

  • in_size = block_size * num_blocks
  • qin_size = type_size * num_blocks

也就是 block_size 个 float32 浮点数对应于 type_size 个 uint8 的整数, gguf 中不同的量化方式对应的 block_sizetype_size 不尽相同, 例如 Q6_Kblock_size=256, type_size=210.

以实际例子来说, 对于 qwen2.5-1.5b-instruct-q5_k_m.gguf 这份模型的 output.weight 权重来说, 其原始权重的数据类型为 float32, 形状为 (out_size, qin_size)=(151936, 1536), 而量化后的数据类型为 uint8, 形状为 (out_size, qin_size)=(151936, 1260), 也就是

  • 量化前第 0 行第 [0, …, 255] 共 256 个 32位浮点数 <—-> 量化后第 0 行第 [0, …, 209] 共 210 个 8位整数
  • 量化前第 0 行第 [256, …, 511] 共 256 个 32位浮点数 <—-> 量化后第 0 行第 [210, …, 419] 共 210 个 8位整数
  • 量化前第 151935 行第 [0, …, 255] 共 256 个 32位浮点数 <—-> 量化后第 151935 行第 [0, …, 209] 共 210 个 8位整数

__Quant 类及其子类实现了量化与反量化的逻辑, 在具体实现上, 每个子类 (例如: Q6_K) 基本上只需要实现 quantize_blocksdequantize_blocks 两个方法即可

  • quantize_blocks: 输入是形状为 (k, block_size), 数据类型为 float32 的数组, 输出形状为 (k, type_size), 数据类型为 uint8 的数组
  • dequantize_blocks: 输入形状为 (k, type_size), 数据类型为 uint8 的数组, 输出是形状为 (k, block_size), 数据类型为 float32 的数组

以下是对 __Quant 类中反量化的相关代码的个人注释, 源代码位于: https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py

个人注释版

可以参照注释前面的标号 (1), (2), … 依次进行追踪阅读

_type_traits = dict[GGMLQuantizationType, type[__Quant]] = {}
QK_K = 256

# (4) 根据量化后的形状计算量化前权重的形状, 在我们的例子中, 输入是 (151936, 1260), 输出是 (151936, 1536)
def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
    # GGML_QUANT_SIZES[GGMLQuantizationType.Q6_K] = (256, 210) = (256, 2+QK_K//2+QK_K//4+QK_K//16)
    # (block_size=256, type_size=210)
    # block_size 与 type_size 的含义是:
    # 量化是按行分组量化的, 也就是说, 量化前是 (151936, 1536) float32, 量化后是 (151936, 1260) uint8
    # 以每行按照 block_size 个元素分组进行量化, 也就是:
    # 量化前第 0 行第 [0, ..., 255] 共 256 个 32位浮点数 <----> 量化后第 0 行第 [0, ..., 209] 共 210 个 8位整数
    # 量化前第 0 行第 [256, ..., 511] 共 256 个 32位浮点数 <----> 量化后第 0 行第 [210, ..., 419] 共 210 个 8位整数
    # ...
    # 量化前第 151935 行第 [0, ..., 255] 共 256 个 32位浮点数 <----> 量化后第 151935 行第 [0, ..., 209] 共 210 个 8位整数
    # ...
    block_size, type_size = GGML_QUANT_SIZES[quant_type]
    if shape[-1] % type_size != 0:
        raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
    return (*shape[:-1], shape[-1] // type_size * block_size)

# (5) 这个函数的工作实质上都是通过 func 来完成的
# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
    # 入参解释:
    # arr: (151936, 1260) uint8, 量化后的表示
    # func 是一个函数, func 用于对多行进行反量化, 输入是 (n, 210*k=1260) 形状的 uint8 数组, 输出是 (n, 256*k=1536) 形状的 float32 数组. 因此本质上输入输出维数与 _apply_over_grouped_rows 是类似的
    # otype: np.float32
    # oshape: (151936, 1536)
    # 出参解释:
    # (151936, 1536) float32, 反量化后的浮点数表示
    rows = arr.reshape((-1, arr.shape[-1]))  # (m, type_size*k) = (151936, 210*6)
    osize = 1
    for dim in oshape:
        osize *= dim
    out = np.empty(shape=osize, dtype=otype)  # (151936*1536,)
    # compute over groups of 16 rows (arbitrary, but seems good for performance)
    n_groups = (rows.shape[0] // 16) or 1

    # 每个 group 的形状是 (151936/16=9496, 1260)
    # 因此 func(group) 的形状输出是 (9496, 1536)
    # ravel() 函数就只是一个 flatten
    # np.concatenate 之后的输出形状是 (151936*1536,)
    np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
    # reshape 为输出维度, 也就是 (151936*1536,)
    return out.reshape(oshape)

class __Quant(ABC):
    qtype: GGMLQuantizationType  # 对于 Q6_K 来说, qtype=GGMLQuantizationType.Q6_K
    block_size: int              # 对于 Q6_K 来说, block_size=256
    type_size: int               # 对于 Q6_K 来说, type_size=210

    # 对于 Q6_K 来说, 以下均为默认值
    grid: np.ndarray[Any, np.dtype[np.float32]] | None = None
    grid_shape: tuple[int, int] = (0, 0)
    grid_map: tuple[int | float, ...] = ()
    grid_hex: bytes | None = None

    def __init__(self):
        return TypeError("Quant conversion classes can't have instances")

    def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
        cls.qtype = qtype
        cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
        cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
            cls.__quantize_array,
            meta_noop=(np.uint8, cls.__shape_to_bytes)
        )
        cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
            cls.__dequantize_array,
            meta_noop=(np.float32, cls.__shape_from_bytes)
        )
        assert qtype not in _type_traits
        _type_traits[qtype] = cls
    
    @classmethod
    def init_grid(cls):
        if cls.grid is not None or cls.grid_hex is None:
            return
        # 此处省略后面的逻辑: ...

    @classmethod
    @abstractmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        raise NotImplementedError
    
    @classmethod
    def __shape_from_bytes(cls, shape: Sequence[int]):
        return quant_shape_from_byte_shape(shape, cls.qtype)

    # (6) 反量化多行的逻辑
    # dequantize_rows 的输入是 (n=9496, 210*k=1260) 形状的 uint8 数组, 输出是 (n=9496, 256*k=1536) 形状的 float32 数组
    @classmethod
    def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
        rows = rows.view(np.uint8)  # 这里什么也没干
        shape = rows.shape
        n_blocks = rows.size // cls.type_size
        blocks = rows.reshape((n_blocks, cls.type_size))  # (9496*6, 210)
        # 对 blocks 进行反量化, 输入形状是 (n*k, type_size), 输出形状是 (n*k, block_size)
        # 也就是子类只需要重载这个 dequantize_blocks 函数即可
        blocks = cls.dequantize_blocks(blocks)  # (9496*6, 256)
        assert blocks.dtype == np.float32
        assert blocks.shape[-1] == cls.block_size
        return blocks.reshape(cls.__shape_from_bytes(shape))  # (n*k, block_size) -> (n, block_size*k)

    # (3) dequantize 的调用实际发生在此处
    @classmethod
    def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
        cls.init_grid()  # 对于 Q6_K 来说, 这里没有执行任何逻辑, 但其他的量化方式, 可能会有逻辑
        return _apply_over_grouped_rows(
            cls.dequantize_rows,
            arr=array,  # (151936, 1260), uint8
            otype=np.float32,
            oshape=cls.__shape_from_bytes(array.shape)  # (151936, 1536)
        )

    # (2) 我们暂时不考虑 LazyNumpyTensor 的逻辑, 因此实际上只是调用 __dequantize_array
    @classmethod
    def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
        if isinstance(tensor, LazyNumpyTensor):
            return cls.__dequantize_lazy(tensor)
        else:
            return cls.__dequantize_array(tensor)

class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
    # 疑惑: Q6_K 在 gguf==0.10.0 版本没有 quantize_blocks
    # (7) 对多个 block (或者说是 group) 进行反量化
    # 输入是: (x=9496*6, type_size=210) uint8
    # 输出是: (x=9496*6, block_size=256) float32
    # 以下注释只先注释形状变化
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        # np.hsplit 的作用是按列进行划分
        # np.hsplit(arr, [2, 5, 10])  # 输出: arr[:, :2], arr[:, 2:5], arr[:, 5:10], arr[:, 10:]
        ql, rest = np.hsplit(blocks, [QK_K // 2])  # ql: (x, 128), rest: (x, 210-128=82)
        qh, rest = np.hsplit(rest, [QK_K // 4])    # qh: (x, 82), rest: (x, 82-64=18)
        scales, d = np.hsplit(rest, [QK_K // 16])  # scales: (x, 16), d: (x, 2)

        # view 不改变数据, 只是改变解释方式, astype 是逐元素进行数据类型转化
        # scales: (x, 16): uint8 -> (x, 16): float32, 注意是 astype(np.float32), 而不是 view(np.float32)
        scales = scales.view(np.int8).astype(np.float32)
        # d: (x, 2): uint8 -> (x, 1): float16->float32, 注意是 view(np.float16)
        d = d.view(np.float16).astype(np.float32)

        # 以下为实际的反量化过程的核心逻辑, 具体解释见后文
        # d: (x, 16, 1)
        d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
        
        # (x, 2, 1, 64), (1, 1, 2, 1) -> (x, 2, 2, 64)
        ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        # # (x, 8, 32)
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
        
        # (x, 2, 1, 32), (1, 1, 4, 1) -> (x, 2, 4, 32)
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        # (x, 8, 32)
        qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32))

        # (x, 8, 32)
        q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32)
        # (x, 16, 16)
        q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32)

        # (x, 16, 1), (x, 16, 16) -> (x, 16, 16) -> (x, 256)
        return (d * q).reshape((n_blocks, QK_K))

# (1) 对外接口, 起始本质上就是 Q6_K.dequantize
# 在 Q6_K 的情形下, qwen2.5-1.5b, data=output.weight (float32) 的形状是 (151936, 1536)
# 其中 151936 是词表长度, 1536 是隐层的维数
# 而将其量化后的表示是 (151936, 1260), uint8
def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
    # data: (151936, 1260), uint8
    if qtype == GGMLQuantizationType.F32:
        return data.view(np.float32)
    elif qtype == GGMLQuantizationType.F16:
        return data.view(np.float16).astype(np.float32)
    elif (q := _type_traits.get(qtype)) is not None:
        return q.dequantize(data)
    else:
        raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")

gguf.quants.Q6_K

以下为 gguf.quants.Q6_Kgguf==0.10.0 版本里的全部源代码.

备注: gguf==0.10.0 版本的 Q6_K 类没有 quantize_blocks 方法, 但实际上应该可以参考 C 语言的实现 https://github.com/ggerganov/llama.cpp/blob/master/ggml/src/ggml-quants.c

QK_K = 256

class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
    # 输入例子是: (x=9496*6, type_size=210) uint8
    # 输出例子是: (x=9496*6, block_size=256) float32
    # 以下注释只先注释形状变化
    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        n_blocks = blocks.shape[0]

        # np.hsplit 的作用是按列进行划分
        # np.hsplit(arr, [2, 5, 10])  # 输出: arr[:, :2], arr[:, 2:5], arr[:, 5:10], arr[:, 10:]
        ql, rest = np.hsplit(blocks, [QK_K // 2])  # ql: (x, 128), rest: (x, 210-128=82)
        qh, rest = np.hsplit(rest, [QK_K // 4])    # qh: (x, 82), rest: (x, 82-64=18)
        scales, d = np.hsplit(rest, [QK_K // 16])  # scales: (x, 16), d: (x, 2)

        # view 不改变数据, 只是改变解释方式, astype 是逐元素进行数据类型转化
        # scales: (x, 16): uint8 -> (x, 16): float32, 注意是 astype(np.float32), 而不是 view(np.float32)
        scales = scales.view(np.int8).astype(np.float32)
        # d: (x, 2): uint8 -> (x, 1): float16->float32, 注意是 view(np.float16)
        d = d.view(np.float16).astype(np.float32)

        # 以下为实际的反量化过程的核心逻辑
        # d: (x, 16, 1)
        d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
        
        # (x, 2, 1, 64), (1, 1, 2, 1) -> (x, 2, 2, 64)
        ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
        # # (x, 8, 32)
        ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
        
        # (x, 2, 1, 32), (1, 1, 4, 1) -> (x, 2, 4, 32)
        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
        # (x, 8, 32)
        qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32))

        # (x, 8, 32)
        q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32)
        # (x, 16, 16)
        q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32)

        # (x, 16, 1), (x, 16, 16) -> (x, 16, 16) -> (x, 256)
        return (d * q).reshape((n_blocks, QK_K))

https://huggingface.co/docs/hub/en/gguf 中对这种量化类型的描述

6-bit quantization (q). Super-blocks with 16 blocks, each block has 16 weights. Weight formula: w = q * block_scale(8-bit), resulting in 6.5625 bits-per-weight.

ps: 计划给 llama.cpp 提 PR, 增加 Q6_K 的量化算法, 需要对齐 C 的实现

原始的 PR

cd /path/to/llama.cpp
make libglmm.so
cd gguf-py
python tests/test_quants.py --libggml ../libggml.so

量化推理

TODO