AMD AI FP8内存布局修复 | AI生成和翻译
问题:您能解释一下这个终端输出和Python脚本中发生了什么吗?
答案:
AI代理(Hermes)正在调试一个底层内存布局错误,以使FP8(8位浮点数)矩阵乘法能够在AMD Instinct MI300X GPU上运行,使用ROCm 7.2。
以下是问题的逐步分解、核心概念以及解决方式。
1. 核心问题:行优先 vs. 列优先
深度学习硬件加速器(如AMD的MI300X或NVIDIA的H100)拥有高度优化的矩阵乘法引擎。PyTorch通过torch._scaled_mm暴露了这一底层特性。
为了达到最大硬件效率,_scaled_mm期望特定的内存布局:
- 矩阵 A 必须是行优先(元素在内存中按行逐行排列)。
- 矩阵 B 必须是列优先(元素在内存中按列逐列排列)。
当Hermes最初尝试使用.to(torch.float8_e4m3fnuz)将矩阵B转换为FP8时,PyTorch隐式地分配了一个全新的张量。问题在于PyTorch的FP8转换会自动将内存布局重置为行优先,从而破坏了矩阵乘法所需的结构要求。
2. 解决方案说明
为了“欺骗”PyTorch,使其向硬件提供完全正确的内容,Hermes操纵了步长(strides,即移动到下一行或下一列所需的内存步数)。
Hermes没有在转换之前强制使用列优先布局,而是改变了操作顺序:
- 转置矩阵 B:
b.t()改变其维度。 - 使其连续:
.contiguous()强制将其锁定为内存中一个干净的行优先块。 - 转换为FP8:
b_t.to(...)安全地将其转换为8位浮点数,同时保持行优先。 - 关键技巧: 当传递给
_scaled_mm(a_fp8, b_t_fp8)时,该函数会将第二个输入视为隐式转置($A \times B^T$)。由于该输入在结构上已被转置且为行优先,读取时再进行转置,硬件就会将其完美地视为列优先矩阵。
3. 基准测试结果
一旦内存布局不匹配问题得到修复,该脚本成功运行了一个性能基准测试,比较了16位脑浮点数(BF16)和新的8位浮点数(FP8)。
| 精度 | 执行时间 | 加速比 | 状态 |
|---|---|---|---|
| BF16(标准) | 标准基线 | 1.00x | 正常 |
| FP8(优化后) | 执行更快 | 1.39倍 | 完全可用 |
注意: 在小型矩阵($256 \times 512$)上获得1.39倍加速已经非常出色。由于硬件张量在大型工作负载下表现更优,在实际的大型语言模型层尺寸(如$4096 \times 4096$)上,这一加速比会显著提升(通常可达2倍到3倍)。
Hermes接下来会做什么?
既然概念验证脚本已经证明FP8在ROCm 7.2上完美运行,Hermes正准备修改nanochat应用程序的源代码(fp8.py),以永久性地为AMD硬件实现这个内存布局补丁。