CODA: 将Transformer块重写为GEMM-尾声程序
摘要
介绍CODA,一种GPU内核抽象,将Transformer操作表达为GEMM加尾声程序以减少数据移动,覆盖Transformer块中几乎所有非注意力计算。
暂无内容
查看缓存全文
缓存时间: 2026/05/22 06:43
# CODA: 将 Transformer 块重写为 GEMM-尾声程序 来源:https://arxiv.org/html/2605.19269 韩郭¹ Jack Zhang² Arjun Menon² Driss Guessous⁴ Vijay Thakkar⁴ Yoon Kim¹ Tri Dao²,³ ¹麻省理工学院 ²普林斯顿大学 ³Together AI ⁴Meta [email protected] ###### 摘要 Transformer 训练系统基于稠密线性代数构建,但端到端时间中相当一部分花在围绕的内存受限算子上。归一化、激活函数、残差更新、归约及相关计算频繁在全局内存中移动大型中间张量,同时进行的算术操作很少,这使得数据移动成为本已高度优化的训练栈中日益重要的瓶颈。我们提出 **CODA**,一种 GPU 内核抽象,将这些计算表达为 GEMM 加尾声的程序。CODA 基于以下观察:许多 Transformer 算子在框架层面是独立的内核,但可以通过代数重新参数化,在 GEMM 输出块仍留在芯片上时执行,之后再写入内存。该抽象固定 GEMM 主循环,并暴露一小组可组合的尾声原语,用于缩放、归约、成对变换和累积。这一受约束的接口既保持了专家编写的 GEMM 的性能结构,又具有足够的表现力,涵盖了标准 Transformer 块前向和后向传播中几乎所有的非注意力计算。在代表性的 Transformer 工作负载上,无论是人类编写还是 LLM 编写的 CODA 内核都实现了高性能,这表明 GEMM 加尾声编程为结合框架级生产力与硬件级效率提供了一条实用路径。¹¹代码见 https://github.com/HanGuo97/coda-kernels。 ## 1 引言 参见图注 图 1:在单个 H100 上使用 TorchTitan 训练 LLaMA-3 风格 1B 模型的运行时分解。LLM 训练已成为系统问题与模型问题并重。现代基于 Transformer 的 LLM 中,FLOP 主要由矩阵乘法和注意力支配,它们的内核已针对 Tensor Core 执行进行了高度优化。然而,Transformer 以及更广泛的深度学习架构还包含归一化、激活函数、残差更新、归约及其他带宽受限的操作,这些操作在内存中移动大型张量,同时进行的算术操作很少。先前工作表明,数据移动是 Transformer 训练的核心瓶颈[7 (https://arxiv.org/html/2605.19269#bib.bib7)];如图 1 (https://arxiv.org/html/2605.19269#S1.F1) 所示,当在单个 H100 上使用 TorchTitan[11 (https://arxiv.org/html/2605.19269#bib.bib11)] 训练 LLaMA-3 风格 1B 模型时,这些非 GEMM 操作占端到端运行时长的相当一部分。随着硬件通过 FP8 和 FP4 等格式加速低精度矩阵乘法,这一瓶颈变得更加重要,因为具体化中间张量的成本并未以相同速度改善。 现有编程模型使得这个问题难以解决。PyTorch 等高级框架将 Transformer 块表达为算子序列,autograd 使得反向传播同样便捷。这很高效,但算子边界往往成为具体化边界,并掩盖前向和反向计算中的融合机会。因此,生产级 LLM 系统常常绕过框架抽象,使用手写的反向传播或自定义内核,如在大型 LLaMA 训练[5 (https://arxiv.org/html/2605.19269#bib.bib5)] 和推理系统[9 (https://arxiv.org/html/2605.19269#bib.bib9),25 (https://arxiv.org/html/2605.19269#bib.bib25)] 中。本文探讨是否存在中间地带:即是否可能在不放弃可编程性和自动化所需结构的情况下,恢复自定义内核的大部分性能? 我们的出发点是:许多在框架层面表现为独立算子的 Transformer 计算,可以通过代数重新参数化,作为 GEMM 加尾声的程序(图 2 (https://arxiv.org/html/2605.19269#S1.F2))。在这种形式下,高度优化的 GEMM 主循环生成输出块,而可编程的尾声在结果写入内存前对块进行局部变换(图 3 (https://arxiv.org/html/2605.19269#S2.F3))。这在 GPU 上高效,因为尾声操作的数据已由 GEMM 块产生,避免了中间张量的额外全局内存往返。借助现代流水线调度(如 Hopper Ping-Pong GEMM 和 Blackwell TMEM 流水线),这些尾声工作常可置于其他块主循环的阴影中。因此,我们将尾声从简单的后处理(如缩放或偏置加法)扩展为结构化接口,用于将内存受限的计算融合到 GEMM 块的生存期内。 基于上述,我们提出 **CODA**,一个实现此接口的内核抽象原型。CODA 保持 GEMM 主循环固定,并暴露一小组可组合的尾声原语,用于缩放、归约、成对变换和累积。这种编程模型刻意受约束,但具有表现力:经过重新参数化后,这些原语涵盖了标准 Transformer 模型前向和反向传播的几乎所有部分,同时保持高效。CODA 在中间张量具体化到全局内存之前,将计算插入已知高性能 GEMM 的尾声中,从而捕获稠密线性代数周围的一大类内存受限计算。Transformer 是我们的主要应用,但同样的 GEMM 加尾声视图也适用于更广泛的场景,即高吞吐量矩阵乘法被块可表达、数据移动受限的计算所包围。 最后,这种结构使自动化更加实用。尾声融合在高性能 GEMM 库中已是成熟技术,但将其应用于 Transformer 工作负载仍是一项底层工程任务。CODA 通过在调优的 GEMM 主循环之上提供 Transformer 特定的尾声原语来填补这一空白。人类或基于 LLM 的作者可以组装这些原语为重新参数化的 Transformer 内核,而非合成任意 CUDA 代码。在代表性工作负载上,两种编写模式均实现高性能,表明领域特定的尾声抽象可以使成熟的 GEMM 融合技术对 LLM 内核更加可编程。 参见图注 图 2:标准 Transformer 层的前向传播。顶行展示了规范形式,对应计算密集型和内存密集型内核的混合。我们对计算进行重新参数化,使得大多数内存密集操作被归入计算密集型内核的尾声中。 ## 2 背景与相关工作 ### 2.1 LLM 系统的编程模型 现代 LLM 系统在多个抽象层次上编程。PyTorch 和 JAX 等框架将模型表达为张量算子图,并与自动微分自然集成,但算子边界往往成为具体化边界。 编译器系统通过图重写、调度、代码生成和自动调优将张量程序降级为优化内核[1 (https://arxiv.org/html/2605.19269#bib.bib1),2 (https://arxiv.org/html/2605.19269#bib.bib2),19 (https://arxiv.org/html/2605.19269#bib.bib19)]。代数重写是另一个重要的性能来源,如 TASO[8 (https://arxiv.org/html/2605.19269#bib.bib8)] 和 Mirage[22 (https://arxiv.org/html/2605.19269#bib.bib22)] 所示。然而,快速演进的加速器使得通用编译器难以达到峰值性能。 更接近硬件层面,程序员使用内核 DSL 和库如 Triton[19 (https://arxiv.org/html/2605.19269#bib.bib19)]、ThunderKittens[14 (https://arxiv.org/html/2605.19269#bib.bib14),13 (https://arxiv.org/html/2605.19269#bib.bib13),17 (https://arxiv.org/html/2605.19269#bib.bib17)]、TileLang[20 (https://arxiv.org/html/2605.19269#bib.bib20)]、CuTeDSL[18 (https://arxiv.org/html/2605.19269#bib.bib18)]、Gluon 和 TLX,或依赖 vLLM[9 (https://arxiv.org/html/2605.19269#bib.bib9)]、SGLang[25 (https://arxiv.org/html/2605.19269#bib.bib25)]、FlashInfer[23 (https://arxiv.org/html/2605.19269#bib.bib23)] 和 Liger Kernels[6 (https://arxiv.org/html/2605.19269#bib.bib6)] 中的专用 LLM 内核。这些方法实现高性能,但将其扩展到新的变换或反向计算仍需大量底层工程工作。 ### 2.2 GEMM 主循环与尾声融合 矩阵乘法是现代 LLM 工作负载的核心计算原语。高性能 GEMM 内核通常分为主循环和尾声两部分。主循环执行分块矩阵乘累加计算,而尾声则变换计算出的输出块并将其高效写回全局内存。 参见图注 图 3:GEMM 主循环计算输出块;尾声对每个块进行变换,然后最终存储到全局内存。尾声是实现融合的自然位置,因为矩阵乘的输出已接近计算核心的片上缓存。实际应用中,尾声通常执行缩放、偏置加法、激活函数、残差更新、数据类型转换、块级归约及其他输出元素操作,从而避免独立的内核启动和额外的全局内存往返。现代内核库直接形式化了这种分离:CUTLASS[18 (https://arxiv.org/html/2605.19269#bib.bib18)] 将 GEMM 内核表示为集体主循环和集体尾声的组合,而尾声访问者树进一步将尾声表达为原语的组合[4 (https://arxiv.org/html/2605.19269#bib.bib4)]。 这种灵活性受限于局部性约束。尾声仅能看到局部输出块、其累加器以及一致索引的辅助张量,这意味着需要全局归约或跨块通信的操作必须重新表述为块局部片段,或在单独的过程中处理。CODA 建立在此接口之上,保持高性能 GEMM 主循环不变,并将尾声作为附近内存受限计算的可编程站点。 ## 3 CODA 上一节论证了 GEMM 尾声是将内存受限计算融合到稠密线性代数中的自然位置。现在我们描述 **CODA**,一个实现该思想的 GPU 内核抽象。第 3.1 节 (https://arxiv.org/html/2605.19269#S3.SS1) 确定了一小组可高效映射到 GPU 执行的尾声原语。第 3.2 节 (https://arxiv.org/html/2605.19269#S3.SS2) 展示了如何利用这些原语将标准 Transformer 前向和反向传播中的非注意力、非嵌入部分重新参数化。最后,第 3.3 节 (https://arxiv.org/html/2605.19269#S3.SS3) 描述了它们的实现以及我们面向 LLM 的编写工作流。 ### 3.1 高效的尾声原语 CODA 对 GEMM 尾声进行编程,同时保持主循环固定且高度优化。对于每个输出块,尾声可以加载辅助数据、变换累加器值、发出辅助结果并存储最终输出。该接口有意限制为块局部计算,而非任意全局通信。我们的尾声模板(见附录 B.1 (https://arxiv.org/html/2605.19269#A2.SS1))受尾声访问者树[4 (https://arxiv.org/html/2605.19269#bib.bib4)] 启发。CODA 提供五类尾声原语: 1. 1.*元素级与成对映射*:对累加器值应用局部变换,包括残差更新、激活函数、RoPE 风格旋转和 SwiGLU 风格门控。 2. 2.*向量(秩-1 张量)加载与存储*:加载行向量或列向量,将其广播到输出块上,并可选择性写入向量值辅助结果。 3. 3.*块(秩-2 张量)加载与存储*:加载或存储矩阵块,例如残差流、保存的激活函数或反向传播所需的中间值。 4. 4.*块(秩-2 张量)归约*:计算输出块的行或列的部分归约,稍后由轻量级辅助内核进行合并。 5. 5.*有状态变换*:维护运行中的块状态,例如在线对数求和指数和交叉熵中使用的最大值与求和指数统计量。 这些原语有意保持狭窄,操作级别足够低以便编译为高效的尾声代码,且具有足够的表现力以捕获我们 Transformer 重新参数化中 GEMM 周围的内存受限操作,如下所示。 ### 3.2 将 Transformer 重新参数化为尾声 现在我们展示上述原语集足以处理大部分 Transformer 计算。经过轻量级代数重新参数化后,标准 Transformer 前向传播中的许多非注意力、非嵌入组件可写为: GEMM: \( \mathbf{h} = \mathbf{x} \mathbf{W} \), Epilogue: \( \mathbf{y}[i,j] = \mathbf{f}[i,j]\!\left(\mathbf{h}[i,j]\right) \), 其中 \([i,j]\) 索引一个输出块,\(\mathbf{f}[i,j]\) 是在 GEMM 尾声中实现的块函数。尾声要么完全是块局部的,要么是块局部但产生部分结果,再由轻量级辅助归约进行合并。我们首先将此视图应用于前向传播,然后展示独立的块函数在反向传播中保留了相同的 GEMM-尾声结构。 #### 3.2.1 GEMM-残差-RMSNorm-GEMM 模式 在预归一化 Transformer 中,重复出现的模式是 GEMM 后接残差更新和归一化,再接另一个 GEMM。该模式出现在多个相邻子层中: 1. 注意力输出投影 → 残差流 → RMSNorm → MLP 门/上投影; 2. MLP 下投影 → 残差流 → RMSNorm → 注意力 QKV 投影; 3. 最终 MLP 下投影 → 残差流 → 最终 RMSNorm → 语言模型头。 尽管这些情况通常属于不同模块的部分,但它们共享相同的计算结构: \[ \mathbf{y} = \operatorname{RMSNorm}(\mathbf{x} \mathbf{W}_0 + \mathbf{z}, \gamma) \mathbf{W}_1 = \Bigl( r \, \bigl( \mathbf{x} \mathbf{W}_0 + \mathbf{z} \bigr) \odot \gamma \Bigr) \mathbf{W}_1, \] 其中 \(\mathbf{z}\) 表示残差流,\(r = 1 / \operatorname{rms}(\mathbf{x} \mathbf{W}_0 + \mathbf{z})\) 是逐行的逆 RMS 因子。该模式跨越通常的模块边界:它将一个子层的输出投影与下一个子层的输入投影耦合在一起。 残差加法和乘以 RMSNorm 权重 \(\gamma\) 是块局部操作,因此可以融合到 GEMM 尾声中。然而,逐行因子 \(r\) 需要跨隐藏维度的归约,该维度大于单个输出块。在规范计算中,\(r\) 在第二个 GEMM 之前应用,这就在归一化与下一个投影之间建立了明显的依赖关系。 参见图注 图 4:GEMM-RMSNorm-GEMM 重新参数化。 我们通过将归约分为两个层次来解决这一问题。第一个 GEMM 尾声计算块局部的部分归约,一个小的辅助内核将这些部分归约跨块合并以得到 \(r\)。由于辅助内核 rea
相似文章
@juleslogs: 想理解现代AI?从这里开始:1. Transformers → Illustrated Transformer 2. LLMs → Build a Large Language Mo…
一条推文,整理了理解现代AI的基础资源,涵盖从Transformer到物理AI的主题,包括关键论文和模型。
@loganthorneloe:阅读此文,开始学习机器学习基础设施。这是对机器学习中重要考虑因素的极好高层概述……
卡内基梅隆大学软件工程研究所发布了一篇机器学习训练基础设施概述,涵盖了硬件考量(如GPU与CPU)以及内存需求等。
@FinanceYF5: Anthropic 正在雇佣 1000 名自由职业软件工程师来训练 Claude Code。 单任务报酬 280 美元。 他们负责编写提示词、比对代码输出、测试模型的追问响应,并且教会 Claude 真实开发者的工作方式。 这简直是在亲手…
Anthropic is hiring 1000 freelance software engineers to train Claude Code, with each task paying $280. The engineers will write prompts, compare code outputs, test model responses, and teach Claude how real developers work.
@FeitengLi: 异步、稀疏,和小数点后第五位:Cursor 训练 Composer 2 的工程细节 https://lattifai.com/zh/podcasts/SequoiaCapital/UDTr9yUnLUI…
本文深入探讨Cursor训练Composer 2模型采用的异步、稀疏等技术细节,并介绍了RL基础设施的全解析。
人口统计偏差对皮肤病变分类的影响
本文研究了人口统计偏差(性别和年龄)对使用ResNet模型进行皮肤病变分类的影响,发现性别偏差源于数据不平衡,而年龄偏差则始终偏向较年轻群体,并评估了多任务学习和对抗性学习的缓解策略。