DataCollator

在 huggingface 的 trainer 里, dataset 一般是使用 datasets 包的 load_dataset 得到的 Dataset 经过 map 等方式处理. 而 dataloader 则一般是通过普通的 torch.utils.data.DataLoader 将 dataset 和这里的 datacollator 作为参数传入得到最终的 dataloader, 而 transformers 的 Model 的 forward 的输入就是 dataloader 的一个 batch

transformers.DataCollatorForLanguageModeling

import torch
from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling

# The tokenizer must exist before the collator is built, because it is passed
# to the collator's constructor.
tokenizer = AutoTokenizer.from_pretrained("./gpt-2/")
# Reuse EOS as the padding token; the expected output of this example shows
# the shorter sample padded with the EOS id.
tokenizer.pad_token = tokenizer.eos_token

# mlm=False: causal-LM collation. As the expected output shows, `labels` is a
# copy of `input_ids` with the padded positions replaced by -100.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Two samples of different length, used to demonstrate padding within a batch.
data = [
    {
        "input_ids": torch.tensor([1000, 1234, 5679]),
        "attention_mask": torch.tensor([1, 1, 1]),
        # "other_text": "text-A",  # not allowed; must be removed beforehand
    },
    {
        "input_ids": torch.tensor([1000, 11111, 2345, 1234, 5679]),
        "attention_mask": torch.tensor([1, 1, 1, 1, 1]),
        # "other_text": "text-B",  # not allowed; must be removed beforehand
    }
]
data_collator(data)

# {
#     'input_ids': tensor([[ 1000,  1234,  5679, 50256, 50256],
#         [ 1000, 11111,  2345,  1234,  5679]]),
#     'attention_mask': tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]]),
#     'labels': tensor([[ 1000,  1234,  5679,  -100,  -100],
#         [ 1000, 11111,  2345,  1234,  5679]])
# }

下载模型/数据相关

验证下载的正确性(主要适用于手动点击下载)

以下载 bert-base-uncased/pytorch_model.bin 文件为例

from huggingface_hub import model_info
# Query the Hub for the repo's metadata. With files_metadata=True the result
# includes per-file entries (size, and for LFS files the sha256) like the
# sample output shown below.
model_info("bert-base-uncased", revision="main", files_metadata=True)

# 输出结果里包含如下输出
# RepoFile: { 
#     {'blob_id': 'ba5d19791be1dd7992e33bd61f20207b0f7f50a5',
#      'lfs': {'pointerSize': 134,
#              'sha256': '097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a',
#              'size': 440473133},
#      'rfilename': 'pytorch_model.bin',
#      'size': 440473133}

检验本地下载的数据是否与上面的信息一致

# Compute the SHA-256 of the locally downloaded file; it must equal the
# 'sha256' value listed under 'lfs' in the metadata output above.
sha256sum pytorch_model.bin  # 097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a

再举一例, 下载一个大规模的数据集, 下载方式采用手动点击链接的方式, 首先生成 lfs 文件的期望 hash 值

# get_hash.py
# Collect the expected SHA-256 of every LFS-tracked file in a Hub dataset repo
# and store them as a JSON mapping {repo-relative filename: sha256}.
from huggingface_hub import dataset_info
import json

repo_id = "Skywork/SkyPile-150B"
validate_filepath = "SkyPile-150B_sha256.json"

# files_metadata=True is required so that each sibling carries its `lfs` info.
info = dataset_info(repo_id, revision="main", files_metadata=True)

# Files without LFS metadata have no sha256 to compare against and are skipped.
sha256_info = {
    sibling.rfilename: sibling.lfs['sha256']
    for sibling in info.siblings
    if sibling.lfs
}

with open(validate_filepath, "w") as fw:
    json.dump(sha256_info, fw, ensure_ascii=False, indent=4)

验证 hash 值的代码: 每次下载完一些新的内容时, 执行一次此脚本进行验证

# validate.py
import hashlib
import json
import os

def get_sha256(filepath):
    """Return the hexadecimal SHA-256 digest of the file at ``filepath``.

    The file is opened in binary mode and fed to the hash in 50 MiB pieces,
    so the whole file is never loaded into memory at once.
    """
    piece_size = 50 * 1024 * 1024
    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for piece in iter(lambda: stream.read(piece_size), b""):
            digest.update(piece)
    return digest.hexdigest()


validate_filepath = "SkyPile-150B_sha256.json"
state_filepath = "local_download_state.json"
root_path = "./"

# Expected digests: {repo-relative path: sha256 hex string}.
with open(validate_filepath, "r") as fr:
    sha256_info = json.load(fr)

# Resume from the state saved by a previous run, if any, so that files already
# verified as "matched" are not hashed again.
if os.path.exists(state_filepath):
    with open(state_filepath) as fr:
        state_dict = json.load(fr)
else:
    state_dict = {key: "undownload" for key in sha256_info}


for prefix, folders, filenames in os.walk(root_path):
    for filename in filenames:
        filepath = os.path.join(prefix, filename)
        relpath = os.path.relpath(filepath, root_path).replace("\\", "/")  # Windows
        if relpath not in sha256_info:
            # Not a file we have an expected hash for.
            continue

        # .get() instead of [...]: a state file saved earlier may lack an entry
        # for this path (e.g. the expected-hash JSON was regenerated with new
        # files); such a file must simply be validated, not raise KeyError.
        if state_dict.get(relpath) == "matched":
            print(relpath, "\033[0;32m skip validate \033[0m")
        else:
            expected_sha256_value = sha256_info[relpath]
            local_sha256_value = get_sha256(filepath)
            if expected_sha256_value == local_sha256_value:
                print(relpath, "\033[0;32m matched \033[0m")  # green
                state_dict[relpath] = "matched"
            else:
                print(relpath, "\033[0;31m unmatched, please redownload !! \033[0m")  # red
                state_dict[relpath] = "unmatched"

        # Every file found on disk is removed, so that what remains in
        # sha256_info afterwards is exactly the set of files not present locally.
        sha256_info.pop(relpath)

print("="*40)
print(f"There are {len(sha256_info)} files should be download")

for relpath in sha256_info:
    print(relpath, "\033[0;33m should download !! \033[0m")

# Persist the per-file state for the next run.
with open(state_filepath, "w") as fw:
    json.dump(state_dict, fw, ensure_ascii=False, indent=4)