@shreyansh_26: https://x.com/shreyansh_26/status/2069125463860302212
摘要
本文介绍了Decompose-K技术,用于加速瘦高大K矩阵乘法,通过将K维度分割成块,执行批量矩阵乘法,并求和部分结果。还提供了PyTorch实现和基准测试,显示对于形状不佳的矩阵乘法,相比标准torch.compile有显著加速。
查看缓存全文
缓存时间: 2026/06/22 21:53
Decompose-K:从 torch.compile 到手工调优的 Triton 内核,用于瘦高型大规模K矩阵乘法
本文源码托管于 GitHub – shreyansh26/MLSys-Experiments/decompose-k
Decompose-K 的思想以及自定义算子自动调优流程源自 PyTorch Conference 报告“闪电演讲:通过子图融合和自定义算子自动调优实现比 SOTA 内核更快的 torch.compile”——Elias Ellison & Paul Zhang, Meta。本文是围绕该思想自主实现的代码走读和基准测试研究。
瘦高型大规模K矩阵乘法问题
标准矩阵乘法为 C[M, N] = A[M, K] @ B[K, N],GPU GEMM 通过分块 M×N 输出矩阵来提取并行性。每个程序拥有 C 的一个 BLOCK_M × BLOCK_N 分块,并在 K 维上进行累加。当 M 和 N 较大时,这种方法效果很好,因为输出分块数量多,GPU 有足够多的独立工作来填满其流式多处理器 (SM)。问题场景是瘦高型、K主导的矩阵乘法:M 和 N 很小,而 K 很大。比如 M = N = 16,K = 32768,或者解码时的 MoE 路由器 GEMM,如 [T, 7168] @ [7168, 256],其中 T 小至 1。此时输出仅为 16×16=256 个元素,相当于一两个分块。GPU 有 132 个 SM 空闲,而一两个程序在长度为 32768 的归约上串行执行。这种矩阵乘法受限于归约操作,但标准的分块策略几乎无法在唯一的大维度上提供并行性。Decompose-K 正是为了解决这种不匹配而提出的重构方法。基本思想很简单:如果唯一大的维度是 K,那么将 K 拆分,并在拆分后的维度上进行并行化。
Decompose-K 的原理
Decompose-K 将长 K 维度拆分成 S 个块,以批处理矩阵乘法的形式运行 S 个部分 GEMM,并将部分结果求和(可在归约存储时融合可选的 epilogue)。具体操作为:
将 K 维拆分成 S 个独立的块,计算 S 个部分 GEMM,然后对部分结果求和:
A[M, K] @ B[K, N] -> partials[S, M, N] -> sum(partials, dim=0)
每个部分是在 K/S 归约长度上的更小矩阵乘法。S 个部分相互独立,因此变成了一个批次维度为 S 的批处理矩阵乘法(bmm)。简化的 PyTorch 实现如下:
def decomposeK(a, b, k_splits):
m, k = a.shape
n = b.shape[1]
assert k % k_splits == 0, "k must be divisible by k_splits"
k_parts = k // k_splits
# [m, k] -> [m, k_splits, k_parts] -> [k_splits, m, k_parts]
a_reshaped = a.reshape(m, k_splits, k_parts).permute(1, 0, 2)
b_reshaped = b.reshape(k_splits, k_parts, n)
# [k_splits, k_parts, n]
result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
reduced_result = result.sum(dim=0)
return reduced_result.to(a.dtype)
关键在于 reshape 带来的好处。对于 M = N = 16,K = 32768,S = 64:
- a_reshaped 为 [64, 16, 512],b_reshaped 为 [64, 512, 16]。
- 现在 bmm 包含 64 个独立的矩阵乘法,而不是一个。这些是调度器可以在 SM 之间分布的 64 个单位工作量,而之前只有一个输出分块。
- 每个部分仅累加 512 的归约长度,而非 32768。我们用 S 个短并行归约替代了一个长串行归约,外加一步对 S 个部分的最终归约。
部分结果以 fp32 累加(out_dtype=torch.float32),因此拆分不会损失相对于单一 fp32 累加矩阵乘法的精度。这本质上就是 split-K,但以张量层次的 bmm 加归约来表达,而不是通过原子操作累加到单个输出分块。这种区别在添加 epilogue 时变得重要,这是下一部分的内容。
为何 epilogue 友好
使用原子操作累加到输出的 split-K 设计在融合元素级 epilogue(如 ReLU)时会遇到问题:在所有拆分完成其原子贡献之前,输出分块并非最终状态,因此无法在累加过程中应用 ReLU。您需要等待所有原子操作完成后,再单独执行 pass。Decompose-K 将部分结果保存在独立的 [S, M, N] 缓冲区中,并进行显式归约。这意味着归约步骤是每个输出元素变为最终值的自然且唯一的位置,因此 epilogue 可以直接折叠到归约的存储中:
acc = sum over splits of partials[:, m, n]
acc = relu(acc) # 融合到同一个内核中
store C[m, n] = acc
无需额外的逐点运算 pass 在 C 上执行,也无需对输出进行二次读写。对于输出很小且 epilogue 受限于内存带宽的情况,这能带来实际的节省,我们稍后将进行测量(相对于非融合 ReLU 约 1.2 倍到 1.4 倍)。
适用场景
Decompose-K 在以下场景中具有吸引力:
- K 非常大,M/N 很小(例如 M = N = 16..64,K = 8192..32768)。
- 工作负载对延迟敏感,特定形状比通用 GEMM 更重要。
- 具体例子:DeepSeek-V3 MoE 路由器 GEMM
[T, 7168] @ [7168, 256],解码时 T 动态变化且很小(1..256),预填充时 T 较大。 - 像 ReLU 这样的融合 epilogue 可以随归约一起完成。
当 M 和 N 已经足够大可以填满 GPU、K 很小、K 难以被候选拆分数量整除、或者额外的 [S, M, N] 缓冲区及其归约占主导成本时,Decompose-K 则不值得使用。
本文后续部分将逐一介绍围绕该思想的实现,从最简单的方式(torch.compile)到手工编写的 Triton 内核,最终超越 Inductor 自身的自动调优选择。所有基准测试均在 H100(132 SM)上以 BF16 格式运行,测试网格为 M = N ∈ {16, 32, 48, 64},K ∈ {8192, …, 32768}。
基线:直接调用 torch.compile
首先尝试用普通 PyTorch 编写 decomposeK,让 Inductor 处理其余部分。相关细节是编译模式。在本文使用的三个基准测试套件中(带融合 ReLU epilogue 的 BF16 矩阵乘法(epilogue-bf16)、普通 BF16 矩阵乘法(matmul-bf16)和普通 FP32 矩阵乘法),max-autotune-no-cudagraphs 是最佳模式,优于 max-autotune:
decomposeK_compiled = torch.compile(decomposeK, mode="max-autotune-no-cudagraphs")
max-autotune 启用 Inductor 的模板自动调优(它会基准测试几个生成的内核并选择最快的)。-no-cudagraphs 变体跳过 CUDA 图捕获,对于这些微小的单次调用,这避免了捕获开销,同时保留了自动调优的好处。
简单编译实际生成了什么?
编译上述 decomposeK 函数(针对路由器形状 [64, 7168] @ [7168, 256],S = 4)会产生两个操作,可以从 Inductor 输出代码中读出:
# extern bmm 进入 fp32 partials 缓冲区
buf0 = empty_strided_cuda((4, 64, 256), (16384, 256, 1), torch.float32)
extern_kernels.bmm_dtype(
reinterpret_tensor(arg0_1, (4, 64, 1792), ...),
reinterpret_tensor(arg1_1, (4, 1792, 256), ...),
out_dtype=torch.float32, out=buf0)
# 一个生成的逐点内核:在 4 个拆分上求和 + 转换为 bf16
triton_poi_fused__to_copy_sum_0.run(buf0, buf1, 16384, ...)
因此,bmm 进入外部(cuBLAS)批处理内核,而 sum(dim=0) 加 .to(bf16) 转换被融合到一个生成的 Triton 逐点内核中。生成的归约内核实际上是 S=4 个切片的展开加法:
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tl.load(in_ptr0 + (16384 + x0), None)
tmp3 = tl.load(in_ptr0 + (32768 + x0), None)
tmp5 = tl.load(in_ptr0 + (49152 + x0), None)
tmp7 = (tmp0 + tmp1 + tmp3 + tmp5)
tl.store(out_ptr0 + (x0), tmp7, None)
这是用 PyTorch 编写的显式 Decompose-K 的调用图:bmm + 融合求和/转换内核。如果编写带 ReLU epilogue 的版本,融合内核将进一步包含 maximum(0, x),因此得到 bmm + 融合求和+relu 内核。epilogue 是免费的,因为它附加在必须运行的归约内核上。
如果只写 relu(mm(a, b)) 让 Inductor 决定会怎样?
这是一个更有趣的问题,因为这里使用的 PyTorch 夜间版(torch==2.12.0.dev20260408+cu128)在 Inductor 内部已经包含了 Decompose-K 的 lowering——参见 https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/decompose_k.py 以及它在 https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/mm.py 中注册的子图选择。因此,对于大 K 形状,Inductor 会自动采用分解,并将其作为候选之一与常规矩阵乘法模板一起自动调优。该验证代码编译了普通的 torch.relu(torch.mm(a, b)),并在两个 K 值下输出生成的代码。
小 K(M = N = 16,K = 256)——Inductor 生成单个融合矩阵乘法模板 triton_tem_fused_mm_relu_0,源节点为 [aten.mm, aten.relu]。ReLU 被融合到矩阵乘法模板的存储后缀中:
# inductor 的模板后缀,位于矩阵乘法内核内部
tmp1 = triton_helpers.maximum(tmp0, acc) # relu
tl.store(out_ptr1 + xindex, tmp1, mask)
一个内核,ReLU 已融合,完成。小 K 时没有理由分解。
大 K(M = N = 16,K = 32768)——现在 Inductor 自行选择 Decompose-K。生成图名为 decompose_k_mm_64_split_5(它选择了 S=64,因此 k_part=512),包含三个部分:
# 1) 通过 cuBLAS 的批处理部分矩阵乘法,fp32 累加
extern_kernels.bmm_dtype(
reinterpret_tensor(arg0_1, (64, 16, 512), ...),
reinterpret_tensor(arg1_1, (64, 512, 16), ...),
out_dtype=torch.float32, out=buf0)
# buf0: [64, 16, 16] fp32
# 2) 在 64 个拆分上生成的归约
triton_per_fused_mm_0.run(buf0, buf2, 256, 64, ...)
# 3) 一个单独的逐点 relu 内核
triton_poi_fused_relu_1.run(buf1, 256, ...)
需注意第 3 部分。当 Inductor 采用 Decompose-K lowering 时,它将 ReLU 作为单独的 triton_poi_fused_relu_1 逐点内核在归约之后发射。它并没有将 ReLU 融合到 Decompose-K 的归约/存储中。这意味着对输出缓冲区进行了一次额外的完整读写。对于微小的 16×16 输出,绝对值很小,但正是手工编写内核可以重新利用的融合机会,也是本文剩余部分所要追求的性能差距。
因此,我们有两个基础事实:Decompose-K 在大 K 时是正确的结构(Inductor 也认同),而 Inductor 的默认 lowering 未融合 epilogue。现在该我们自行编写内核了。
手工编写的 Triton 内核
源码:kernels/decompose_k_triton_kernel.py - https://github.com/shreyansh26/MLSys-Experiments/tree/main/decompose-k/kernels/decompose_k_triton_kernel.py
该内核分为两个阶段,与上述结构对应:一个部分矩阵乘法内核填充 [S, M, N],一个归约/epilogue 内核在 S 上求和,并在存储时可选地应用 ReLU。
第一阶段:部分矩阵乘法
部分矩阵乘法内核使用二维启动网格:program_id(0) 索引 M×N 输出分块(采用常见的 L2 友好的组主 swizzle),program_id(1) 索引拆分。每个程序计算一个拆分的一个 BLOCK_M × BLOCK_N 分块,仅累加其 K // SPLIT_K 切片。
@triton.jit
def _partial_mm(a, b, partials, ...):
pid = tl.program_id(0)
split_id = tl.program_id(1)
# 对 pid 进行组主 swizzle -> (pid_m, pid_n) 以实现 L2 重用
...
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
k_per_split = K // SPLIT_K
split_k_start = split_id * k_per_split
acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
for k0 in range(0, k_per_split, BLOCK_K):
k_offsets = k0 + offs_k
a_ptrs = a + offs_m[:, None] * stride_am + (split_k_start + k_offsets[None, :]) * stride_ak
b_ptrs = b + (split_k_start + k_offsets[:, None]) * stride_bk + offs_n[None, :] * stride_bn
k_mask = k_offsets < k_per_split
a_vals = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & k_mask[None, :], other=0.0)
b_vals = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
acc += tl.dot(a_vals, b_vals, out_dtype=tl.float32, input_precision=INPUT_PRECISION)
partial_ptrs = partials + split_id * stride_ps + offs_m[:, None] * stride_pm + offs_n[None, :] * stride_pn
tl.store(partial_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
值得指出的几个细节:
- 累加器为 fp32,无论输入精度如何,
input_precision对于 fp32 输入设为"ieee",否则为"tf32"。这保证了拆分不会改变与单次累加矩阵乘法相比的数值行为。 split_k_start = split_id * k_per_split是区分不同拆分程序的唯一因素。每个拆分读取 K 中连续的一段k_per_split数据。- 存储写入按拆分索引的
partials[split_id]切片。没有原子操作:每个(split_id, tile)对拥有 partials 缓冲区中的不相交区域。
第二阶段:归约 + 融合 epilogue
归约启动每个输出分块一个程序,循环遍历所有 SPLIT_K 个 part,累加到一个分块形状的累加器,如果要求则应用 ReLU,然后存储:
@triton.jit
def _reduce_epilogue(partials, c, ..., SPLIT_K, BLOCK_M, BLOCK_N, FUSE_RELU):
# 与之前相同的 (pid -> pid_m, pid_n) swizzle
...
acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
for split_id in range(0, SPLIT_K):
acc += tl.load(ptrs + split_id * stride_ps, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0)
if FUSE_RELU:
acc = tl.maximum(acc, 0.0)
tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
这就是 Inductor 的 Decompose-K lowering 没有实现的融合:ReLU 在寄存器中应用,然后一次性存储 C,无需单独的逐点传递。从正确性角度,这是安全的,因为显式归约正是每个输出元素首次成为最终值的地方。
这个内核正确且合理,但它有一个结构性局限,即归约的并行化方式——稍后我们会指出这一点。令人惊讶的是,这个手工编写的内核实际上没有击败 Inductor。
自定义算子自动调优:让 Inductor 选择分解
源码:custom_op_autotune_relu_dispatch.py - https://github.com/shreyansh26/MLSys-Experiments/tree/main/decompose-k/custom_op_autotune_relu_dispatch.py
Inductor 暴露了一个 API register_custom_op_autotuning,允许您提供一个算子的多个替代分解列表,让 Inductor 为每个形状进行基准测试并选择最优,然后将最优者进行 lower。巧妙之处在于,目标算子可以是真实的 @torch.library.custom_op 或者现有的 ATen 重载,如 torch.ops.aten.mm.default。因此,您可以拦截编译图中每个 torch.mm 的 lowering。
候选包括普通 mm(或 mm + relu)以及每个有效拆分数的 Decompose-K 分解:
K_SPLITS = (2, 4, 8, 16, 32, 64, 128, 256)
def generate_mm_relu_configs(fake_tensors):
k = int(fake_tensors["a"].shape[1])
splits = [s for s in K_SPLITS if k % s == 0]
configs = [CustomOpConfig(mm_relu_impl)]
configs += [CustomOpConfig(decompose_k_relu_impl, k_splits=s) for s in splits]
return configs
decompose_k_relu_impl 就是本文开头的 PyTorch 级 bmm + sum + relu;我们不是向 Inductor 提供 Triton 内核,而是提供几个数学上等价的 PyTorch 分解,让 Inductor 对每个进行 lower 和计时。
该脚本在两个不同边界进行注册,每个都有匹配的配置生成器,以便覆盖普通矩阵乘法和融合的 matmul+ReLU 情况:
**aten.mm 边界**——generate_mm_configs,键为self/mat2。候选包括mm_impl(普通的torch.mmlower)以及每个能整除 K 的decompose_k_impl(k_splits=s)。ReLU 保持在自动调优的算子外部,作为一个单独的逐点内核。- 融合的
mm_relu自定义算子边界——generate_mm_relu_configs,键为a/b。候选包括mm_relu_impl(普通的relu(mm))以及每个能整除 K 的decompose_k_relu_impl(k_splits=s)。
相似文章
@shreyansh_26: 当 M 和 N 很小而 K 很大时,如何让矩阵乘法变快?(MoE routers、small-batch decode。)Decompose-K: …
一种加速矩阵乘法的技术,适用于 M 和 N 较小而 K 较大的情况(如 MoE routers 和 small-batch decoding),通过分解 K 并并行运行部分 GEMM,然后将 epilogue 折叠到归约存储中。该方法使用自定义 Triton 内核,在大多数形状上击败了 PyTorch Inductor。
@leloykun:[进行中] 关于 Lean4-to-TileLang 张量程序超级优化器的博文:
一篇技术博文介绍了一种 Lean4-to-TileLang 张量程序超级优化器,能自动生成优化的 GPU/TPU 内核与超参数缩放规律,展示了相较 torch.compile 的性能提升。
Block-sparse GPU kernels
OpenAI 发布 block-sparse GPU kernels,这是一款用于在 GPU 上进行高效稀疏矩阵乘法的工具,可以减少神经网络操作的计算量和内存占用。
@leloykun: 我又忙忘了时间 >.< 最近如果给我发过私信,真的非常抱歉。我保证会逐一查看!--- 在本次迭代中,我……
作者开发了一个从 Lean4 到 TileLang 的张量程序超优化器,能够自动生成优化后的加速器内核并推导超参数缩放定律,在 A100 GPU 上实现了 1.8 倍的加速。
@hamzaelshafie: 新深度博客文章:《剖析ThunderKittens:高性能AI内核的紧凑型DSL解剖》这篇帖子……
一篇详细分析ThunderKittens的博客文章,ThunderKittens是用于高性能AI内核的紧凑型DSL。文章包括从底向上的抽象分析,以及一个实现非因果注意力预填充内核的基准测试,该内核比FlashAttention-2快约1.55倍,与FlashAttention-3性能相当。