最后,衷心感谢这个了不起的团队:@jcz42, Arjun, Driss, @tensorcore, @yoonrkim 和 @tri_dao!PDF: https://a…
摘要
CODA 引入了一种 GPU 内核抽象,将 transformer 计算重写为 GEMM-plus-epilogue 程序,减少内存受限操作,提高训练效率。
查看缓存全文
缓存时间: 2026/05/23 01:56
最后,特别感谢出色的团队:@jcz42、Arjun、Driss、@tensorcore、@yoonrkim 和 @tri_dao!PDF:https://arxiv.org/abs/2605.19269 代码:https://github.com/HanGuo97/coda-kernels…
CODA:将Transformer块重写为GEMM-尾声程序
来源:https://arxiv.org/html/2605.19269 韩郭¹ 张杰克² 阿琼·梅农² 德里斯·盖索斯⁴ 维贾伊·塔卡尔⁴ 尹·金¹ 道三²,³ ¹麻省理工学院 ²普林斯顿大学 ³Together AI ⁴Meta [email protected]
摘要
Transformer训练系统建立在密集线性代数之上,然而,端到端时间中相当大的一部分花在了外围的内存受限操作上。归一化、激活函数、残差更新、归约及相关计算会反复将大型中间张量移入全局内存,同时执行很少的算术运算,这使得数据移动在原本高度优化的训练栈中成为日益重要的瓶颈。我们引入了CODA,一种GPU内核抽象,将这些计算表达为GEMM-加-尾声程序。CODA基于这样一个观察:许多暴露为独立框架内核的Transformer操作符可以通过代数重新参数化,在GEMM输出块仍在芯片上时执行,然后再写入内存。该抽象固定了GEMM主循环,并公开了一小组可组合的尾声原语,用于缩放、归约、逐对变换和累积。这种受限接口保留了专家编写的GEMM的性能结构,同时其表达能力足以覆盖标准Transformer块前向和反向传播中几乎所有非注意力计算。在代表性的Transformer工作负载上,无论是人类编写还是LLM编写的CODA内核都实现了高性能,这表明GEMM-加-尾声编程为将框架级生产力与硬件级效率相结合提供了一条实用路径。¹¹代码可在https://github.com/HanGuo97/coda-kernels获取。
1 引言
参见图注 图1:使用TorchTitan在单个H100上训练LLaMA-3风格1B模型的运行时分解。LLM训练已经成为一个同样重要的系统问题,而不仅仅是建模问题。现代基于Transformer的LLM中的FLOPs主要由矩阵乘法和注意力主导,其内核已针对Tensor Core执行进行了高度优化。然而,Transformer以及更广泛的深度学习架构,也包含归一化、激活函数、残差更新、归约以及其他带宽受限的操作,这些操作在内存中移动大型张量,同时进行的算术运算很少。先前的工作表明,数据移动是Transformer训练中的核心瓶颈[7 (https://arxiv.org/html/2605.19269#bib.bib7)];如图1 (https://arxiv.org/html/2605.19269#S1.F1)所示,当使用TorchTitan[11 (https://arxiv.org/html/2605.19269#bib.bib11)]在单个H100上训练LLaMA-3风格1B模型时,这些非GEMM操作占用了端到端运行时中不可忽视的一部分。随着硬件通过FP8和FP4等格式越来越加速低精度矩阵乘法,这个瓶颈变得更加重要,因为实例化中间张量的成本并没有以同样的速度提升。
现有的编程模型使得这个问题难以解决。诸如PyTorch之类的高级框架将Transformer块表达为操作符序列,自动求导使得反向传播同样便捷。这很有生产力,但操作符边界往往成为实例化边界,并掩盖了前向和反向计算中的融合机会。因此,生产级LLM系统通常会绕过框架抽象,采用手写的反向传播或自定义内核,如大规模LLaMA训练[5 (https://arxiv.org/html/2605.19269#bib.bib5)]和推理系统[9 (https://arxiv.org/html/2605.19269#bib.bib9),25 (https://arxiv.org/html/2605.19269#bib.bib25)]中所示。这项工作探讨是否存在一个中间地带。也就是说,是否可以在不放弃可编程性和自动化所需结构的情况下,恢复自定义内核的大部分性能?
我们的出发点是,许多在框架层面表现为独立操作符的Transformer计算,可以通过代数重新参数化为GEMM-加-尾声程序(图2 (https://arxiv.org/html/2605.19269#S1.F2))。在这种形式中,一个高度优化的GEMM主循环生成输出块,而一个可编程的尾声则在结果写入内存之前执行块局部变换(图3 (https://arxiv.org/html/2605.19269#S2.F3))。这在GPU上是高效的,因为尾声操作在GEMM块已经产生的数据上运行,避免了中间张量的额外全局内存往返。借助现代流水线调度,这个尾声工作通常可以隐藏在其他块主循环的影子中,如Hopper Ping-Pong GEMM和Blackwell基于TMEM的流水线。因此,我们将尾声扩展到不仅仅是简单后处理(如缩放或偏置添加)的地方,将其提升为一个结构化接口,用于将内存受限计算融合到GEMM块的整个生命周期中。
基于上述,我们引入了CODA,一种实现此接口的内核抽象原型。CODA保持GEMM主循环固定,并公开了一小组可组合的尾声原语,用于缩放、归约、逐对变换和累积。这种编程模型有意设置得受限,但具有表现力,因为在重新参数化后,这些原语几乎覆盖了标准Transformer模型的整个前向和反向传播,同时保持了效率。CODA在中间张量实例化到全局内存之前,将计算插入到已知高性能GEMM的尾声中,捕获了围绕密集线性代数的一大类内存受限计算。Transformer是我们的主要应用,但同样的GEMM-加-尾声视图更广泛地适用于每当高吞吐量矩阵乘法被可块表达、数据移动受限的计算所包围时。
最后,这种结构使自动化变得更加实用。尾声融合在高性能GEMM库中已经确立,但将其应用于Transformer工作负载仍然是一项低层次的工程任务。CODA的目标是通过在调优的GEMM主循环之上提供Transformer特定的尾声原语来填补这一空白。基于人类或LLM的作者可以将这些原语组装成重新参数化的Transformer内核,而不是合成任意的CUDA。在代表性的工作负载上,两种创作模式都实现了高性能,这表明领域特定的尾声抽象可以使既定的GEMM融合技术对LLM内核更具可编程性。
参见图注 图2:标准Transformer层的前向传播。顶行显示了规范公式,它映射到计算型和内存型内核的混合。我们对计算进行重新参数化,以便大多数内存受限操作被吸收到计算型内核的尾声中。
2 背景与相关工作
2.1 LLM系统的编程模型
现代LLM系统在多个抽象层级上进行编程。诸如PyTorch和JAX之类的框架将模型表示为张量操作符图,并自然集成自动微分,但操作符边界往往成为实例化边界。
编译器系统通过图重写、调度、代码生成和自动调优将张量程序降低为优化后的内核[1 (https://arxiv.org/html/2605.19269#bib.bib1),2 (https://arxiv.org/html/2605.19269#bib.bib2),19 (https://arxiv.org/html/2605.19269#bib.bib19)]。代数重构是另一个重要的性能来源,如TASO[8 (https://arxiv.org/html/2605.19269#bib.bib8)]和Mirage[22 (https://arxiv.org/html/2605.19269#bib.bib22)]所示。然而,快速发展的加速器使得通用编译器的峰值性能成为一个移动的目标。
在更接近硬件的层面,程序员使用内核DSL和库,如Triton[19 (https://arxiv.org/html/2605.19269#bib.bib19)]、ThunderKittens[14 (https://arxiv.org/html/2605.19269#bib.bib14),13 (https://arxiv.org/html/2605.19269#bib.bib13),17 (https://arxiv.org/html/2605.19269#bib.bib17)]、TileLang[20 (https://arxiv.org/html/2605.19269#bib.bib20)]、CuTeDSL[18 (https://arxiv.org/html/2605.19269#bib.bib18)]、Gluon和TLX,或者依赖vLLM[9 (https://arxiv.org/html/2605.19269#bib.bib9)]、SGLang[25 (https://arxiv.org/html/2605.19269#bib.bib25)]、FlashInfer[23 (https://arxiv.org/html/2605.19269#bib.bib23)]和Liger Kernels[6 (https://arxiv.org/html/2605.19269#bib.bib6)]中的专门LLM内核。这些方法提供了高性能,但将它们扩展到新的变换或反向计算仍然需要大量的底层工程工作。
2.2 GEMM主循环与尾声融合
矩阵乘法是现代LLM工作负载中的核心计算原语。一个高性能的GEMM内核通常分为主循环和尾声两部分。主循环执行分块的矩阵乘加计算,而尾声则变换计算出的输出块并高效地写回全局内存。
参见图注 图3:GEMM主循环计算输出块;尾声在每个块最终存储到全局内存之前对其进行变换。尾声是实现融合的自然位置,因为矩阵乘法的输出已经在片上靠近计算核心。实际的尾声通常执行缩放、偏置添加、激活函数、残差更新、数据类型转换、块级归约以及其他输出元素级操作,从而避免了单独的内核启动和额外的全局内存往返。现代内核库直接形式化了这种分离:CUTLASS[18 (https://arxiv.org/html/2605.19269#bib.bib18)]将GEMM内核表示为一个集体主循环和一个集体尾声的组合,而Epilogue Visitor Trees进一步将尾声表达为原语的组合[4 (https://arxiv.org/html/2605.19269#bib.bib4)]。
这种灵活性在局部性约束下运作。一个尾声只能看到局部的输出块、其累加器以及一致索引的辅助张量,这意味着需要全局归约或跨块通信的操作必须重新表述为块局部部分,或在单独的通路中处理。CODA建立在此接口之上,保持高性能GEMM主循环固定,并将尾声用作附近内存受限计算的可编程位置。
3 CODA
上一节论证了GEMM尾声是将内存受限计算融合到密集线性代数中的自然位置。我们现在描述CODA,一种实现这一思想的GPU内核抽象。第3.1节 (https://arxiv.org/html/2605.19269#S3.SS1) 确定了一小组在GPU执行中高效的尾声原语。第3.2节 (https://arxiv.org/html/2605.19269#S3.SS2) 展示了Transformer前向和反向传播中的非注意力和非嵌入部分如何通过这些原语进行重新参数化。最后,第3.3节 (https://arxiv.org/html/2605.19269#S3.SS3) 描述了它们的实现和我们的面向LLM的创作工作流。
3.1 高效的尾声原语
CODA对GEMM尾声进行编程,同时保持主循环固定且高度优化。对于每个输出块,尾声可以加载辅助数据、变换累加器值、输出辅助结果以及存储最终输出。此接口有意限制为块局部计算,而不是任意的全局通信。我们的尾声模板(如附录B.1节 (https://arxiv.org/html/2605.19269#A2.SS1) 所示)受到Epilogue Visitor Trees[4 (https://arxiv.org/html/2605.19269#bib.bib4)]的启发。CODA提供了五类尾声原语:
- 1.*逐元素和逐对映射:*对累加器值应用局部变换,包括残差更新、激活函数、RoPE风格的旋转以及SwiGLU风格的门控。
- 2.*向量(秩1张量)加载和存储:*加载行向量或列向量,将其广播到整个输出块,并可选地写入向量值的辅助结果。
- 3.*块(秩2张量)加载和存储:*加载或存储矩阵块,例如残差流、保存的激活或反向传播所需的中间值。
- 4.*块(秩2张量)归约:*计算输出块的行或列的部分归约,稍后由轻量级辅助内核合并。
- 5.*状态转换:*维护运行中的块状态,例如在线log-sum-exp和交叉熵中使用的最大值和sum-exp统计量。
这些原语有意定义得狭窄,其级别足够低,可以编译为高效的尾声代码,同时表达能力足够强,能够捕获Transformer重新参数化中围绕GEMM的内存受限操作,如下所示。
3.2 将Transformer重写为尾声
我们现在展示上述原语集对于大部分Transformer计算来说是足够的。在经过轻量级的代数重新参数化后,标准Transformer前向传播的许多非注意力和非嵌入组件可以写为:
GEMM:h = xW, 尾声:y[i,j] = fi,j, 其中[i,j]索引一个输出块,f[i,j]是在GEMM尾声中实现的块函数。尾声要么是完全块局部的,要么是块局部直到部分结果,这些结果由轻量级辅助归约合并。我们首先将此视图应用于前向传播,然后展示独立的块函数在反向传播中保留了相同的GEMM-尾声结构。
3.2.1 GEMM-残差-RMSNorm-GEMM模式
预归一化Transformer中一个重复出现的模式是GEMM后跟残差更新和归一化,然后又是另一个GEMM。该模式出现在几个相邻的子层中:
- 注意力输出投影→残差流→RMSNorm→MLP门/上投影;
- MLP下投影→残差流→RMSNorm→注意力QKV投影;
- 最终MLP下投影→残差流→最终RMSNorm→语言模型头。
尽管这些情况通常作为不同模块的部分编写,但它们共享相同的计算结构:
y = RMSNorm(xW₀ + z, γ)W₁ = (r(xW₀ + z) ⊙ γ)W₁, 其中z表示残差流,r = 1/rms(xW₀ + z)是逐行的逆RMS因子。该模式跨越了通常的模块边界:它将一个子层的输出投影与下一个子层的输入投影耦合起来。
残差相加和乘以RMSNorm权重γ是块局部的,因此可以融合到GEMM尾声中。然而,逐行因子r需要跨隐藏维度的归约,这大于单个输出块。在标准计算中,r在第二个GEMM之前应用,这在归一化和下一个投影之间创建了一个明显的依赖关系。
参见图注 图4:GEMM-RMSNorm-GEMM重新参数化。
相似文章
@reach_vb: https://x.com/reach_vb/status/2057880274348695995
一名用户演示了使用OpenAI的Codex自动生成一个Colab笔记本,该笔记本在JAX/Flax/Optax中训练一个约1000万参数的transformer进行加法运算,在T4 GPU上经过4000步后达到了高准确率。
@pauliusztin_: 我刚找到了理解 GPU 最实用的资源之一。再也不用在不同文档、PDF 和论坛帖子之间跳来跳去了…
Modal Labs 发布了一个开源的 GPU 术语词典,将零散的 NVIDIA 文档、CUDA 细节及编译器参数整合为单一的可导航资源,旨在帮助工程师优化 LLM 的训练与推理。
@leloykun:[进行中] 关于 Lean4-to-TileLang 张量程序超级优化器的博文:
一篇技术博文介绍了一种 Lean4-to-TileLang 张量程序超级优化器,能自动生成优化的 GPU/TPU 内核与超参数缩放规律,展示了相较 torch.compile 的性能提升。
@optimalab1: 高度赞扬 Barbara Su(莱斯大学计算机科学 -> 斯坦福大学硕士):她主导了整个端到端流程:算法、GLUE/SQuAD 流水线…
介绍 AdaPaD,一种用于 LoRA 微调的并行秩-1 缩减方法,使得低秩线性回归组件可以并行计算而非顺序计算,提高了效率。
@leloykun: 我又忙忘了时间 >.< 最近如果给我发过私信,真的非常抱歉。我保证会逐一查看!--- 在本次迭代中,我……
作者开发了一个从 Lean4 到 TileLang 的张量程序超优化器,能够自动生成优化后的加速器内核并推导超参数缩放定律,在 A100 GPU 上实现了 1.8 倍的加速。