ROCm 矩阵乘法源代码 | AI生成和翻译
现在我对ROCm的矩阵乘法源代码有了透彻的理解。以下是全面的解释:
ROCm如何在源代码中实现矩阵乘法
ROCm主要通过三个库来实现矩阵乘法,每个库提供不同层次的抽象:
1. rocBLAS — BLAS库入口点
文件: library/src/blas3/rocblas_gemm.hpp
rocBLAS实现了标准的BLAS GEMM(通用矩阵乘法)API:C = α·op(A)·op(B) + β·C。
主机端流程从rocblas_gemm_impl()(位于rocblas_gemm_imp.hpp)开始。在验证参数并将α/β复制到主机后,它调用:
status = ROCBLAS_API(rocblas_internal_gemm_template)(handle, trans_a, trans_b,
m, n, k, alpha, A, 0, lda, 0, B, 0, ldb, 0, beta, C, 0, ldc, 0, 1);
这会分发到Tensile后端。
2. Tensile — 内核生成引擎(主要GEMM后端)
Tensile是一个YAML驱动的内核生成器,用于创建手工优化的汇编内核。rocBLAS ↔ Tensile的接口位于:
文件: library/src/tensile_host.cpp
分发流程
rocblas_gemm()
→ rocblas_internal_gemm_template()
→ runContractionProblem() [tensile_host.cpp]
→ ConstructTensileProblem() 构建Tensile::ContractionProblem
→ GetTensileInputs() 设置A、B、C、D的GPU指针
→ library->findBestSolution() 从预调优库中选择最佳内核
→ adapter.launchKernels() 启动GPU内核
关键代码(来自tensile_host.cpp):
// 针对此问题规模找到最佳GPU内核
solution = library->findBestSolution(tensile_prob, *hardware, fitness_query);
// 启动内核
hipError_t hip_status = adapter.launchKernels(
solution->solve(tensile_prob, GetTensileInputs(prob), *hardware),
handle->get_stream(), ...);
Tensile预编译了数千个经过调优的内核变体(针对不同的M、N、K大小、数据类型、GPU架构),并将其存储为.co代码对象文件,位于/opt/rocm/lib/rocblas/library/。
Tensile内核编写器(汇编)
文件: Tensile/KernelWriterAssembly.py
Tensile生成实际的GCN/AMDGPU汇编(.s文件)。内核编写器会生成类似v_mfma_f32_16x16x4f32的MFMA指令。例如:
# 来自KernelWriterAssembly.py
class KernelWriterAssembly(KernelWriter):
def __init__(self, ...):
self.do["MAC"] = True # 乘加运算
self.do["GlobalReadA"] = True
self.do["GlobalReadB"] = True
self.do["LocalWrite"] = True
self.do["GlobalWrite"] = True
它会生成类似以下的汇编代码:
v_mfma_f32_16x16x4f32 v[0:3], v4, v5, v[0:3] // C += A * B
3. Composable Kernel (CK) — 现代C++模板库(较新方法)
仓库: https://github.com/ROCm/composable_kernel
CK是一种基于现代C++模板的方法。它采用基于tile的编程模型,构建在AMDGPU内建函数之上。
三层层次结构
第一层 — 网格级GEMM(内核入口):
GridGemm
└─ BlockGemm (每个线程块)
└─ WarpGemm (每个波前)
└─ MFMA / WMMA指令
第二层 — 块级GEMM(基于共享内存):
文件: ck/tutorial/ck_tile/gemm/01_naive_gemm/block_gemm_asmem_bsmem_creg.hpp
// C += A * B (A和B来自共享内存,C在寄存器中)
template <typename Problem, typename Policy>
struct BlockGemmASmemBSmemCReg {
template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c,
const ABlockWindow& a,
const BBlockWindow& b) const {
// 在内层循环中遍历K:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
AWarpTensor a_warp = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
BWarpTensor b_warp = load_tile(b_warp_windows(nIter)(kIter));
WarpGemm{}(c_warp, a_warp, b_warp); // ← 实际的乘加运算
});
});
});
}
};
第三层 — 波前级GEMM(实际的MFMA指令):
文件: ck/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
这里是实际的GPU矩阵乘法指令被调用的地方:
// F32 GEMM:使用MFMA的16x16x4 tile
struct WarpGemmAttributeMfmaImplF32F32F32M16N16K4 {
// c_vec += a_vec * b_vec
template <typename... Params>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const AVecType& b_vec) const {
#if defined(__gfx9__)
// 直接使用内建函数
c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
#else
// 或使用内联汇编
asm volatile("v_mfma_f32_16x16x4f32 %0, %1, %2, %3\n"
: "+v"(c_vec)
: "v"(a_vec), "v"(b_vec), "v"(c_vec));
#endif
}
};
4. hipBLASLt — 轻量级备用后端
rocBLAS还支持hipBLASLt作为轻量级后端(位于hipblaslt_host.cpp)。这提供了使用hipBLASLt的hipblasLtMatmul() API的替代路径。
5. rocWMMA — 波前矩阵乘加API
文件: rocwmma/internal/mma_impl.hpp
rocWMMA提供了一个库级别的WMMA API,封装了硬件MFMA指令,支持不同的数据布局(行主序/列主序)。
架构总结
PyTorch / TensorFlow / 用户应用
│
┌────▼────┐
│ rocBLAS │ ← C = α·A·B + β·C (BLAS GEMM API)
└────┬────┘
│
┌────▼──────────┐
│ Tensile │ ← YAML定义、预调优的汇编内核
│ (主要) │ 加载目标GPU的.co代码对象
└────┬──────────┘
│
┌────▼──────────────────┐
│ Composable Kernel (CK)│ ← 现代C++模板库
│ (较新/可选) │ 基于tile,使用内建函数
└────┬──────────────────┘
│
┌────▼────────────┐
│ AMDGPU ISA │
│ v_mfma / v_wmma │ ← 硬件矩阵指令
└─────────────────┘
在AMD GPU上,实际的矩阵乘法最终归结为MFMA(矩阵融合乘加)指令——这是NVIDIA Tensor Core的AMD等效物——通过内联汇编或编译器内建函数(如__builtin_amdgcn_mfma_f32_16x16x4f32)来调用。