多头注意力维度 | AI生成和翻译

Home 2026.06

问题: 在多头部注意力重塑中,维度 B、nh、T、hs 分别代表什么,以及 (1, 12, 2, 64) 如何对应它们?

回答:

正确映射。给定 v = (1, 12, 2, 64)

维度 符号 含义 示例中的值
B 批量大小 并行处理的序列数量 1
nh 头部数量 独立的注意力子空间数量 12
T 序列长度 令牌数量 2
hs 头部大小 每个头部的嵌入维度 64

为什么重塑很重要:

在 view 之前的原始形状:(B, T, C) = (1, 2, 768)

经过 view(B, T, self.n_head, C // self.n_head)(1, 2, 12, 64)

经过 transpose(1, 2)(1, 12, 2, 64) = (B, nh, T, hs)

后续计算:

# 注意力分数:(B, nh, T, T)
scores = (q @ k.transpose(-2, -1)) * (hs ** -0.5)
attn = softmax(scores, dim=-1)

# 输出:(B, nh, T, hs)
out = attn @ v

每个 12 个头独立地对 2 个令牌计算注意力,使用 64 维的键/查询/值。然后将头部拼接回去得到 C=768。这种并行化是多头部注意力能够很好扩展的原因。

在你的例子中(T=2),你处理的是一个非常短的序列——这对调试很有用,但实际推理中 T 的取值范围从几百到 10 万以上令牌,取决于上下文窗口。


Back Donate