OpenWebText数据集预处理脚本 | AI生成和翻译

Home 2025.11

概述

该Python脚本用于处理包含OpenWebText数据集的大型文本文件(这是一个网络抓取文本语料库,类似于用于训练GPT-2等模型的数据)。其目标是:

该脚本针对多核系统效率进行了优化,在分词时使用多进程处理。其设计灵感来自Flash Attention代码库中的数据加载模块(代码中已链接),该模块处理语言模型训练的类似预处理。注意:OpenWebText数据量巨大(约40GB未压缩),但本脚本假设已预下载本地的openwebtext.txt文件。输出文件小得多:train.bin约17GB(90亿个token),val.bin约8.5MB(440万个token)。

脚本开始时打印代理设置(可能是为了调试任何隐式下载时的网络问题,尽管此处没有显式下载)。默认使用8个工作进程进行分词。

逐步分解

1. 导入和初始设置

import os
import tarfile
from tqdm import tqdm
import numpy as np
import tiktoken
from huggingface_hub import hf_hub_download
from datasets import load_dataset # huggingface datasets
import datasets

print("HTTP_PROXY:", os.getenv("HTTP_PROXY"))
print("HTTPS_PROXY:", os.getenv("HTTPS_PROXY"))

# .map()调用中的工作进程数
# 建议使用约CPU核心数//2的值
num_proc = 8

# load_dataset()调用中的工作进程数
# 最佳数值可能与上面的num_proc不同,因为它还取决于网络速度
# 但通常比1更好
num_proc_load_dataset = num_proc

enc = tiktoken.get_encoding("gpt2")

datasets.logging.set_verbosity_info()

if __name__ == '__main__':保护确保主逻辑仅在脚本直接执行时运行(而不是导入时)

2. 读取和分割文本文件

if __name__ == '__main__':
    # 读取本地openwebtext.txt文件
    txt_file = os.path.join(os.path.dirname(__file__), 'openwebtext.txt')
    print(f"从本地文件读取: {txt_file}")

    # 读取文本内容
    texts = []
    with open(txt_file, 'r', encoding='utf-8', errors='ignore') as f:
        # 读取整个文件
        full_text = f.read().strip()

        # 首先尝试通过双换行符分割成文档
        documents = full_text.split('\n\n')

        # 如果只得到一个文档,通过单换行符分割
        if len(documents) <= 1:
            documents = full_text.split('\n')

        # 如果仍然只有一个文档,通过句点加空格分割
        if len(documents) <= 1:
            # 在句点加空格处分割,然后将句子重新连接
            sentences = full_text.split('. ')
            # 将句子分组为每个文档约100句的块
            chunk_size = 100
            for i in range(0, len(sentences), chunk_size):
                chunk = '. '.join(sentences[i:i+chunk_size])
                if chunk.strip():
                    texts.append(chunk.strip() + '.')
        else:
            # 处理从双/单换行符分割得到的文档
            for doc in documents:
                doc = doc.strip()
                if doc:  # 仅添加非空文档
                    texts.append(doc)

        print(f"从文本文件创建了{len(texts)}个文档")

3. 创建和分割数据集

    # 从文本创建数据集
    dataset = datasets.Dataset.from_dict({'text': texts})

    # 从10k示例创建训练/验证分割
    split_dataset = dataset.train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset['val'] = split_dataset.pop('test') # 将测试分割重命名为val

4. 分词函数

    # 现在想要对数据集进行分词。首先定义编码函数(gpt2 bpe)
    def process(example):
        ids = enc.encode_ordinary(example['text']) # encode_ordinary忽略任何特殊标记
        ids.append(enc.eot_token) # 添加文本结束标记,例如gpt2 bpe的50256
        # 注意:我认为eot应该前置而不是追加...嗯。但它被称为"eot"...
        out = {'ids': ids, 'len': len(ids)}
        return out

5. 应用分词

    # 对数据集进行分词
    tokenized = split_dataset.map(
        process,
        remove_columns=['text'],
        desc="tokenizing the splits",
        num_proc=num_proc,
    )

6. 将分词数据保存到二进制文件

    # 将每个数据集中的所有id连接成一个大文件,可用于训练
    for split, dset in tokenized.items():
        arr_len = np.sum(dset['len'], dtype=np.uint64)
        filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
        dtype = np.uint16 # (可以这样做因为enc.max_token_value == 50256 < 2**16)
        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))

        # 基于数据集大小使用自适应批处理大小
        total_batches = min(1024, len(dset))
        if total_batches < 1024:
            print(f"对{split}数据集使用{total_batches}个批次(大小:{len(dset)})")

        idx = 0
        for batch_idx in tqdm(range(total_batches), desc=f'写入{filename}'):
            # 仅当此批次索引对数据集大小有效时处理
            if batch_idx < len(dset):
                # 将样本批处理在一起以加快写入速度
                batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
                arr_batch = np.concatenate(batch['ids'])
                # 写入mmap
                arr[idx : idx + len(arr_batch)] = arr_batch
                idx += len(arr_batch)
        arr.flush()

7. 输出和使用说明的注释

    # train.bin约17GB,val.bin约8.5MB
    # train有约90亿个token(9,035,582,198)
    # val有约440万个token(4,434,897)

    # 以后读取bin文件,例如使用numpy:
    # m = np.memmap('train.bin', dtype=np.uint16, mode='r')

潜在问题和注意事项

此脚本将原始文本转换为训练就绪的格式,针对像GPT这样的语言模型进行了优化。如果您对修改它有疑问(例如不同的分词器或块大小),请告诉我!


Back

x-ai/grok-4-fast

Donate