AMD AI FP8内存布局修复 | AI生成和翻译

Home 2026.06

问题:您能解释一下这个终端输出和Python脚本中发生了什么吗?

答案:

AI代理(Hermes)正在调试一个底层内存布局错误,以使FP8(8位浮点数)矩阵乘法能够在AMD Instinct MI300X GPU上运行,使用ROCm 7.2

以下是问题的逐步分解、核心概念以及解决方式。


1. 核心问题:行优先 vs. 列优先

深度学习硬件加速器(如AMD的MI300X或NVIDIA的H100)拥有高度优化的矩阵乘法引擎。PyTorch通过torch._scaled_mm暴露了这一底层特性。

为了达到最大硬件效率,_scaled_mm期望特定的内存布局:

当Hermes最初尝试使用.to(torch.float8_e4m3fnuz)将矩阵B转换为FP8时,PyTorch隐式地分配了一个全新的张量。问题在于PyTorch的FP8转换会自动将内存布局重置为行优先,从而破坏了矩阵乘法所需的结构要求。


2. 解决方案说明

为了“欺骗”PyTorch,使其向硬件提供完全正确的内容,Hermes操纵了步长(strides,即移动到下一行或下一列所需的内存步数)。

Hermes没有在转换之前强制使用列优先布局,而是改变了操作顺序:

  1. 转置矩阵 B: b.t() 改变其维度。
  2. 使其连续: .contiguous() 强制将其锁定为内存中一个干净的行优先块。
  3. 转换为FP8: b_t.to(...) 安全地将其转换为8位浮点数,同时保持行优先。
  4. 关键技巧: 当传递给_scaled_mm(a_fp8, b_t_fp8)时,该函数会将第二个输入视为隐式转置($A \times B^T$)。由于该输入在结构上已被转置且为行优先,读取时再进行转置,硬件就会将其完美地视为列优先矩阵。

3. 基准测试结果

一旦内存布局不匹配问题得到修复,该脚本成功运行了一个性能基准测试,比较了16位脑浮点数(BF16)和新的8位浮点数(FP8)。

精度 执行时间 加速比 状态
BF16(标准) 标准基线 1.00x 正常
FP8(优化后) 执行更快 1.39倍 完全可用

注意: 在小型矩阵($256 \times 512$)上获得1.39倍加速已经非常出色。由于硬件张量在大型工作负载下表现更优,在实际的大型语言模型层尺寸(如$4096 \times 4096$)上,这一加速比会显著提升(通常可达2倍到3倍)。

Hermes接下来会做什么?

既然概念验证脚本已经证明FP8在ROCm 7.2上完美运行,Hermes正准备修改nanochat应用程序的源代码(fp8.py),以永久性地为AMD硬件实现这个内存布局补丁。


Back Donate