多头注意力维度 | AI生成和翻译
问题: 在多头部注意力重塑中,维度 B、nh、T、hs 分别代表什么,以及 (1, 12, 2, 64) 如何对应它们?
回答:
正确映射。给定 v = (1, 12, 2, 64):
| 维度 | 符号 | 含义 | 示例中的值 |
|---|---|---|---|
| B | 批量大小 | 并行处理的序列数量 | 1 |
| nh | 头部数量 | 独立的注意力子空间数量 | 12 |
| T | 序列长度 | 令牌数量 | 2 |
| hs | 头部大小 | 每个头部的嵌入维度 | 64 |
为什么重塑很重要:
在 view 之前的原始形状:(B, T, C) = (1, 2, 768)
- C 是整个嵌入维度(12 个头 × 64 维 = 768)
经过 view(B, T, self.n_head, C // self.n_head):(1, 2, 12, 64)
- 将 C 拆分为独立的头部
经过 transpose(1, 2):(1, 12, 2, 64) = (B, nh, T, hs)
- 将头部维度前移,使得注意力在每个头部上并行计算
后续计算:
# 注意力分数:(B, nh, T, T)
scores = (q @ k.transpose(-2, -1)) * (hs ** -0.5)
attn = softmax(scores, dim=-1)
# 输出:(B, nh, T, hs)
out = attn @ v
每个 12 个头独立地对 2 个令牌计算注意力,使用 64 维的键/查询/值。然后将头部拼接回去得到 C=768。这种并行化是多头部注意力能够很好扩展的原因。
在你的例子中(T=2),你处理的是一个非常短的序列——这对调试很有用,但实际推理中 T 的取值范围从几百到 10 万以上令牌,取决于上下文窗口。