14.6B Tokens的最佳模型大小 | AI生成和翻译
机器:MI300X, 192 GB HBM3, 单 GPU
数据:14.6B 个 token(147 个分片 × 100M uint16,磁盘占用 28 GB)
现有配置:
| Config | Params | Architecture | Designed for |
|---|---|---|---|
| train_gpt2.py | 124M | 12L/768d/12h | 8×A100, 300B tokens |
| train_gpt2_200m.py | 219M | 18L/896d/16h | RTX 4070 12GB, 4B tokens |
| train_fineweb.py | 124M | 12L/768d/12h | 12GB VRAM |
| train_fineweb_gpt3.py | 124M | 12L/768d/12h | RTX 4070 12GB, 10B tokens |
| train_fineweb1_5b.py | 1.5B | 48L/1600d/25h | H200 143GB, ~7B tokens |
这些配置都不是为 192 GB VRAM + 14.6B tokens 设计的。 我们需要一个新的配置。
Chinchilla 缩放定律分析:
Chinchilla 论文(Hoffmann 等人,2022)指出,最优训练 token 数 ≈ 20× 参数量。因此:
- 124M → 2.5B tokens(我们拥有 14.6B → 严重过度训练)
- 350M → 7B tokens(接近,但我们更多)
- 760M → 15.2B tokens ← 14.6B tokens 的最佳点
- 1.5B → 30B tokens(我们将欠训练 2 倍)
- 3B → 60B tokens(严重欠训练)
但是 Chinchilla 是一个计算最优定律——它最小化每个 FLOP 的损失。如果我们有额外的 VRAM 并且不介意使用更多 FLOP,那么在更少的 token 上训练的更大模型仍然会给出更好的绝对损失。LLaMA-1 论文表明,在 1T tokens 上训练的 7B 模型在同等规模下优于 Chinchilla 最优模型。
内存预算(bf16 训练,Adam):
每个参数:2(bf16 参数)+ 4(fp32 主参数)+ 4(m)+ 4(v)+ 2(梯度)≈ 16 字节
| Model | Params | Train mem | + Activations | Total | Fits 192GB? |
|---|---|---|---|---|---|
| 760M | 760M | 12 GB | ~10 GB | ~22 GB | ✅ easily |
| 1.5B | 1.5B | 24 GB | ~20 GB | ~44 GB | ✅ easily |
| 3B | 3B | 48 GB | ~40 GB | ~88 GB | ✅ yes |
| 7B | 7B | 112 GB | ~60 GB | ~172 GB | ✅ tight |
| 13B | 13B | 208 GB | ~80 GB | ~288 GB | ❌ no |
推荐:3B 模型(GPT-2 3B)
为什么选择 3B 而非 7B:
- 3B 在 14.6B tokens 上训练 = 约 5 倍过度训练。这没问题——LLaMA-1 7B 在 1T tokens 上训练(Chinchilla 最优的 143 倍)。过度训练总是有助于降低绝对损失。
- 7B 在 14.6B tokens 上训练 = 约 2 倍欠训练。你会得到一个尚未收敛的模型。
- 3B 可以轻松容纳大批量大小 → 训练更快
- 约 5 倍过度训练的 3B 将击败约 2 倍欠训练的 7B
建议配置:GPT-2 3B
# 3B 参数:n_layer=48, n_head=32, n_embd=2560
# bf16 下每层约 10.2 GB
# 总训练内存:约 48 GB 参数/优化器 + 约 40 GB 激活值 ≈ 88 GB(远低于 192)
# 允许大批量大小以提高吞吐量
batch_size = 32 # 32 × 1024 = 32K tokens per micro-step
block_size = 1024
gradient_accumulation_steps = 64 # effective batch = 32 × 1024 × 64 = 2M tokens
n_layer = 48
n_head = 32 # head_size = 2560/32 = 80
n_embd = 2560
dropout = 0.0
bias = False
learning_rate = 3e-4
min_lr = 3e-5
warmup_iters = 2000
max_iters = 11500 # 11500 × 2M = 23B tokens(数据遍历约 1.6 倍)
lr_decay_iters = 11500
weight_decay = 0.1
beta2 = 0.95
grad_clip = 1.0
compile = True
替代方案:如果你想训练更长时间并更快迭代,可以从 1.5B(GPT-2 XL 架构)开始并进行多次遍历:
# 1.5B:重用 train_fineweb1_5b.py,但使用更大的批量
n_layer = 48
n_head = 25
n_embd = 1600
batch_size = 64 # 轻松适配 192GB
gradient_accumulation_steps = 32 # effective = 64 × 1024 × 32 = 2M tokens
max_iters = 15000 # 15000 × 2M = 30B tokens(数据遍历 2 倍)
你想要哪一种——3B(更好的最终损失,约 2-3 天)还是 1.5B(更快的迭代,可以尝试更多实验)?