CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs
Summary
Introduces CODA, a GPU kernel abstraction that expresses Transformer operations as GEMM-plus-epilogue programs to reduce data movement, covering nearly all non-attention computation in a Transformer block.
Cached at: 05/22/26, 06:43 AM
# CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs
Source: [https://arxiv.org/html/2605.19269](https://arxiv.org/html/2605.19269)
Han Guo¹, Jack Zhang², Arjun Menon², Driss Guessous⁴, Vijay Thakkar⁴, Yoon Kim¹, Tri Dao²,³
¹Massachusetts Institute of Technology, ²Princeton University, ³Together AI, ⁴Meta
hanguo@mit\.edu
###### Abstract
Transformer training systems are built around dense linear algebra, yet a nontrivial fraction of end\-to\-end time is spent on surrounding memory\-bound operators\. Normalization, activations, residual updates, reductions, and related computations repeatedly move large intermediate tensors through global memory while performing little arithmetic, making data movement an increasingly important bottleneck in otherwise highly optimized training stacks\. We introduce CODA, a GPU kernel abstraction that expresses these computations as GEMM\-plus\-epilogue programs\. CODA is based on the observation that many Transformer operators exposed as separate framework kernels can be algebraically reparameterized to execute while a GEMM output tile remains on chip, before it is written to memory\. The abstraction fixes the GEMM mainloop and exposes a small set of composable epilogue primitives for scaling, reductions, pairwise transformations, and accumulation\. This constrained interface preserves the performance structure of expert\-written GEMMs while remaining expressive enough to cover nearly all non\-attention computation in the forward and backward pass of a standard Transformer block\. Across representative Transformer workloads, both human\- and LLM\-authored CODA kernels achieve high performance, suggesting that GEMM\-plus\-epilogue programming offers a practical path toward combining framework\-level productivity with hardware\-level efficiency\. Code is available at [https://github\.com/HanGuo97/coda\-kernels](https://github.com/HanGuo97/coda-kernels)\.
## 1 Introduction
Figure 1: Runtime breakdown for LLaMA\-3\-style 1B model training on a single H100 using TorchTitan\.
LLM training has become just as much of a systems problem as a modeling one\. FLOPs in modern Transformer\-based LLMs are dominated by matrix multiplications and attention, whose kernels have been heavily optimized for Tensor Core execution\. Yet Transformers, and deep learning architectures more broadly, also contain normalization, activations, residual updates, reductions, and other bandwidth\-limited operations that move large tensors through memory while doing little arithmetic\. Prior work has shown that data movement is a central bottleneck in Transformer training \[[7](https://arxiv.org/html/2605.19269#bib.bib7)\]; as [Figure 1](https://arxiv.org/html/2605.19269#S1.F1) shows, when training a LLaMA\-3\-style 1B model on a single H100 using TorchTitan \[[11](https://arxiv.org/html/2605.19269#bib.bib11)\], these non\-GEMM operations account for a nontrivial fraction of end\-to\-end runtime\. As hardware increasingly accelerates low\-precision matrix multiplication through formats such as FP8 and FP4, this bottleneck is becoming more important, as the cost of materializing intermediate tensors does not improve at the same rate\.
Existing programming models make this issue difficult to address\. High\-level frameworks such as PyTorch express Transformer blocks as operator sequences, with autograd making the backward pass similarly convenient\. This is productive, but operator boundaries often become materialization boundaries and obscure fusion opportunities across forward and backward computation\. Production\-level LLM systems therefore often bypass framework abstractions with hand\-written backward passes or custom kernels, as in large\-scale LLaMA training \[[5](https://arxiv.org/html/2605.19269#bib.bib5)\] and inference systems \[[9](https://arxiv.org/html/2605.19269#bib.bib9),[25](https://arxiv.org/html/2605.19269#bib.bib25)\]\. This work asks whether there is a middle ground: is it possible to recover much of the performance of custom kernels without giving up the structure needed for programmability and automation?
Our starting observation is that many Transformer computations that appear at the framework level as separate operators can be algebraically reparameterized as GEMM\-plus\-epilogue programs \([Figure 2](https://arxiv.org/html/2605.19269#S1.F2)\)\. In this form, a highly optimized GEMM mainloop produces output tiles, while a programmable epilogue performs tile\-local transformations before the result is written to memory \([Figure 3](https://arxiv.org/html/2605.19269#S2.F3)\)\. This is efficient on GPUs because the epilogue operates on data that is already produced by the GEMM tile, avoiding additional global\-memory round trips for intermediate tensors\. With modern pipelined schedules, this epilogue work can often be placed in the shadow of other tiles’ mainloops, as in Hopper Ping\-Pong GEMM and Blackwell TMEM\-based pipelines\. Thus, we extend the epilogue beyond a place for simple post\-processing such as scaling or bias addition, elevating it into a structured interface for fusing memory\-bound computation into the lifetime of a GEMM tile\.
Based on the above, we introduce CODA, a kernel abstraction prototype that realizes this interface\. CODA keeps the GEMM mainloop fixed and exposes a small set of composable epilogue primitives for scaling, reductions, pairwise transformations, and accumulation\. This programming model is deliberately constrained yet expressive: after reparameterization, these primitives cover nearly the entire forward and backward pass of a standard Transformer model while preserving efficiency\. CODA inserts computation into the epilogue of a known high\-performance GEMM before intermediate tensors are materialized in global memory, capturing a broad class of memory\-bound computations surrounding dense linear algebra\. Transformers are our primary application, but the same GEMM\-plus\-epilogue view applies more broadly whenever high\-throughput matrix multiplication is surrounded by tile\-expressible, data\-movement\-bound computations\.
Finally, this structure makes automation more practical\. Epilogue fusion is already established in high\-performance GEMM libraries, but applying it to Transformer workloads remains a low\-level engineering task\. CODA targets this gap by providing Transformer\-specific epilogue primitives on top of a tuned GEMM mainloop\. Human or LLM\-based authors can assemble these primitives into reparameterized Transformer kernels rather than synthesizing arbitrary CUDA\. Across representative workloads, both authoring modes achieve high performance, suggesting that domain\-specific epilogue abstractions can make established GEMM fusion techniques more programmable for LLM kernels\.
Figure 2: Forward pass of a standard Transformer layer\. The top row shows the canonical formulation, which maps to a mix of compute\- and memory\-bound kernels\. We reparameterize the computation so that most memory\-bound operations are subsumed into the epilogues of compute\-bound kernels\.
## 2 Background and Related Works
### 2\.1 Programming Models for LLM Systems
Modern LLM systems are programmed at multiple abstraction levels\. Frameworks such as PyTorch and JAX express models as tensor\-operator graphs and integrate naturally with automatic differentiation, but operator boundaries often become materialization boundaries\.
Compiler systems lower tensor programs to optimized kernels through graph rewriting, scheduling, code generation, and autotuning \[[1](https://arxiv.org/html/2605.19269#bib.bib1),[2](https://arxiv.org/html/2605.19269#bib.bib2),[19](https://arxiv.org/html/2605.19269#bib.bib19)\]\. Algebraic reformulation is another important source of performance, as shown by TASO \[[8](https://arxiv.org/html/2605.19269#bib.bib8)\] and Mirage \[[22](https://arxiv.org/html/2605.19269#bib.bib22)\]\. However, rapidly evolving accelerators make peak performance a moving target for general\-purpose compilers\.
Closer to the hardware level, programmers use kernel DSLs and libraries such as Triton \[[19](https://arxiv.org/html/2605.19269#bib.bib19)\], ThunderKittens \[[14](https://arxiv.org/html/2605.19269#bib.bib14),[13](https://arxiv.org/html/2605.19269#bib.bib13),[17](https://arxiv.org/html/2605.19269#bib.bib17)\], TileLang \[[20](https://arxiv.org/html/2605.19269#bib.bib20)\], CuTeDSL \[[18](https://arxiv.org/html/2605.19269#bib.bib18)\], Gluon, and TLX, or rely on specialized LLM kernels in vLLM \[[9](https://arxiv.org/html/2605.19269#bib.bib9)\], SGLang \[[25](https://arxiv.org/html/2605.19269#bib.bib25)\], FlashInfer \[[23](https://arxiv.org/html/2605.19269#bib.bib23)\], and Liger Kernels \[[6](https://arxiv.org/html/2605.19269#bib.bib6)\]\. These approaches deliver high performance, but extending them to new transformations or backward computations still requires substantial low\-level engineering\.
### 2\.2 GEMM Mainloops and Epilogue Fusion
Matrix multiplication is the central compute primitive in modern LLM workloads\. A high\-performance GEMM kernel is typically divided into a mainloop and an epilogue\. The mainloop performs the tiled matrix multiply\-accumulate computation, while the epilogue transforms the computed output tile and efficiently writes it back to global memory\.
Figure 3: A GEMM mainloop computes output tiles; the epilogue transforms each tile before the final global\-memory store\.
The epilogue is a natural place to implement fusions because the output of the matmul is already present on chip close to compute cores\. Practical epilogues commonly perform scaling, bias addition, activations, residual updates, data type conversions, tile\-wise reductions and other output elementwise operations, avoiding separate kernel launches and extra global\-memory round trips\. Modern kernel libraries formalize this separation directly: CUTLASS \[[18](https://arxiv.org/html/2605.19269#bib.bib18)\] represents GEMM kernels as a composition of a collective mainloop and a collective epilogue, while Epilogue Visitor Trees further express epilogues as compositions of primitives \[[4](https://arxiv.org/html/2605.19269#bib.bib4)\]\.
This flexibility operates under a locality constraint\. An epilogue sees only the local output tile, its accumulators, and consistently indexed auxiliary tensors, meaning that operations requiring global reductions or cross\-tile communication must be reformulated into tile\-local pieces or handled in a separate pass\. CODA builds on this interface, keeping the high\-performance GEMM mainloop fixed and using the epilogue as a programmable site for nearby memory\-bound computation\.
## 3 CODA
The previous section argued that GEMM epilogues are a natural place to fuse memory\-bound computation into dense linear algebra\. We now describe CODA, a GPU kernel abstraction that realizes this idea\. [Section 3\.1](https://arxiv.org/html/2605.19269#S3.SS1) identifies a small set of epilogue primitives that map efficiently to GPU execution\. [Section 3\.2](https://arxiv.org/html/2605.19269#S3.SS2) shows how the non\-attention and non\-embedding portions of the Transformer forward and backward pass can be reparameterized using these primitives\. Finally, [Section 3\.3](https://arxiv.org/html/2605.19269#S3.SS3) describes their implementation and our LLM\-oriented authoring workflow\.
### 3\.1 Efficient Epilogue Primitives
CODA programs the GEMM epilogue while keeping the mainloop fixed and highly optimized\. For each output tile, an epilogue may load auxiliary data, transform accumulator values, emit auxiliary results, and store the final output\. This interface is deliberately restricted to tile\-local computation rather than arbitrary global communication\. Our epilogue template, shown in [Section B\.1](https://arxiv.org/html/2605.19269#A2.SS1), is inspired by Epilogue Visitor Trees \[[4](https://arxiv.org/html/2605.19269#bib.bib4)\]\. CODA provides five classes of epilogue primitives:
1. *Elementwise and pairwise maps:* apply local transformations to accumulator values, including residual updates, activations, RoPE\-style rotations, and SwiGLU\-style gates\.
2. *Vector \(Rank\-1 Tensor\) loads and stores:* load row or column vectors, broadcast them over an output tile, and optionally write vector\-valued auxiliary results\.
3. *Tile \(Rank\-2 Tensor\) loads and stores:* load or store matrix tiles, such as residual streams, saved activations, or intermediate values needed by the backward pass\.
4. *Tile \(Rank\-2 Tensor\) reductions:* compute partial reductions over rows or columns of an output tile, to be combined later by a lightweight auxiliary kernel\.
5. *Stateful transforms:* maintain running tile state, such as the max and sum\-exp statistics used in online log\-sum\-exp and cross\-entropy\.
These primitives are intentionally narrow, operating at a level low enough to compile to efficient epilogue code and expressive enough to capture the memory\-bound operations surrounding GEMMs in our Transformer reparameterizations, as shown next\.
### 3\.2 Reparameterizing Transformers as Epilogues
We now show that the primitive set above is sufficient for much of Transformer computation\. After lightweight algebraic reparameterization, many non\-attention and non\-embedding components of a standard Transformer forward pass can be written as
$$\text{GEMM:}\quad \bm{h} = \bm{x}\bm{W}, \qquad \text{Epilogue:}\quad \bm{y}[i,j] = \bm{f}[i,j]\!\left(\bm{h}[i,j]\right),$$
where $[i,j]$ indexes an output tile and $\bm{f}[i,j]$ is the tile function implemented in the GEMM epilogue\. The epilogue is either fully tile\-local, or tile\-local up to partial results that are combined by a lightweight auxiliary reduction\. We first apply this view to the forward pass, then show that independent tile functions preserve the same GEMM\-epilogue structure in the backward pass\.
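The GEMM-plus-epilogue form can be illustrated with a small NumPy reference. This is only a sketch of the idea, not CODA's implementation: the function `gemm_with_epilogue`, the tile sizes, and the bias-plus-ReLU epilogue are hypothetical choices made for illustration, and the loops stand in for what a GPU kernel would do on chip.

```python
import numpy as np

def gemm_with_epilogue(x, w, tile_fn, tile_m=4, tile_n=4):
    """Reference tiled GEMM; tile_fn plays the role of the epilogue f[i, j]."""
    m, n = x.shape[0], w.shape[1]
    y = np.empty((m, n), dtype=x.dtype)
    for i in range(0, m, tile_m):
        for j in range(0, n, tile_n):
            # "Mainloop": compute one output tile h[i, j].
            h_tile = x[i:i + tile_m] @ w[:, j:j + tile_n]
            # "Epilogue": transform the tile before it is stored.
            y[i:i + tile_m, j:j + tile_n] = tile_fn(h_tile, i, j)
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16))
w = rng.standard_normal((16, 12))
bias = rng.standard_normal(12)

def bias_relu(h_tile, i, j):
    # Tile-local: uses only the tile and the matching slice of the bias vector.
    return np.maximum(h_tile + bias[j:j + h_tile.shape[1]], 0.0)

fused = gemm_with_epilogue(x, w, bias_relu)
unfused = np.maximum(x @ w + bias, 0.0)  # separate "kernels" on the full output
assert np.allclose(fused, unfused)
```

The epilogue only ever sees one tile and a consistently indexed slice of an auxiliary vector, which is the locality constraint described in Section 2.2.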
#### 3\.2\.1 GEMM\-Residual\-RMSNorm\-GEMM Pattern
A recurring pattern in pre\-normalized Transformers is a GEMM followed by a residual update and normalization, then another GEMM\. This pattern appears across several adjacent sublayers:
1. attention output projection → residual stream → RMSNorm → MLP gate/up projection;
2. MLP down projection → residual stream → RMSNorm → attention QKV projection;
3. final MLP down projection → residual stream → final RMSNorm → language modeling head\.
Although these cases are usually written as parts of different modules, they share the same computational structure:
$$\bm{y} = \operatorname{RMSNorm}(\bm{x}\bm{W}_0 + \bm{z}, \bm{\gamma})\,\bm{W}_1 = \Bigl(r\,\bigl(\bm{x}\bm{W}_0 + \bm{z}\bigr) \odot \bm{\gamma}\Bigr)\bm{W}_1,$$
where $\bm{z}$ denotes the residual stream and $r = 1/\operatorname{rms}(\bm{x}\bm{W}_0 + \bm{z})$ is the row\-wise inverse RMS factor\. This pattern crosses the usual module boundary: it couples the output projection of one sublayer with the input projection of the next\.
Residual addition and multiplication by the RMSNorm weight $\bm{\gamma}$ are tile\-local, so they can be fused into a GEMM epilogue\. The row\-wise factor $r$, however, requires a reduction across the hidden dimension, which is larger than a single output tile\. In the canonical computation, $r$ is applied before the second GEMM, creating an apparent dependency between normalization and the next projection\.
Figure 4: GEMM\-RMSNorm\-GEMM reparameterization\.
We address the reduction by splitting it into two levels\. The first GEMM epilogue computes tile\-local partial reductions, and a small auxiliary kernel reduces these partials across tiles to obtain $r$\. Since the auxiliary kernel reads a few partial values per tile rather than the full activation tensor, its memory traffic is much smaller than that of a standalone RMSNorm\.
The apparent dependency on $r$ can be removed algebraically\. Since $r$ is shared across the row, it commutes with the second GEMM:
$$\bm{y} = \Bigl(r\,\bigl(\bm{x}\bm{W}_0 + \bm{z}\bigr) \odot \bm{\gamma}\Bigr)\bm{W}_1 = r\,\Bigl(\bigl(\bm{x}\bm{W}_0 + \bm{z}\bigr) \odot \bm{\gamma}\Bigr)\bm{W}_1.$$
Thus, the row\-wise scale does not need to be applied before the second GEMM\. It can instead be delayed to the epilogue of the second GEMM, after the projection has been computed\.
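To spell out why the delay is valid, the following index-level derivation is a sketch that writes $\bm{a} = \bm{x}\bm{W}_0 + \bm{z}$ (the symbol $\bm{a}$ is introduced here only for illustration) and expands one output entry:

$$
\bm{y}_{ik} = \sum_{j} \bigl(r_i\, \bm{a}_{ij}\, \bm{\gamma}_j\bigr)\,(\bm{W}_1)_{jk} = r_i \sum_{j} \bigl(\bm{a}_{ij}\, \bm{\gamma}_j\bigr)\,(\bm{W}_1)_{jk}.
$$

Because $r_i$ does not depend on the summation index $j$, it factors out of the inner product for every output column $k$. The equality is exact in real arithmetic; in finite precision the two orderings can round differently.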
Concretely, the computation decomposes into two GEMMs and one lightweight reduction \([Figure 4](https://arxiv.org/html/2605.19269#S3.F4)\):
$$
\begin{aligned}
\text{GEMM 1:}\quad \bm{h}_0 &= \bm{x}\bm{W}_0, &\qquad \text{Epilogue 1:}\quad \bm{h}_1[i,j] &= \bm{h}_0[i,j] + \bm{z}[i,j],\\
&& \bm{h}_2[i,j] &= \bm{h}_1[i,j] \odot \bm{\gamma}[j],\\
&& \widehat{\bm{r}}[i,j] &= \operatorname{partialRMS}(\bm{h}_1[i,j]),\\
\text{GEMM 2:}\quad \bm{h}_3 &= \bm{h}_2\bm{W}_1, &\qquad \text{Epilogue 2:}\quad \bm{y}[i,j] &= r[i]\,\bm{h}_3[i,j].
\end{aligned}
$$
Figure 5: Benchmarks\.
Here $r = 1/\sqrt{\operatorname{reduce}(\widehat{\bm{r}}) + \epsilon}$ is computed by a small auxiliary reduction over the tile partials\. This decomposition replaces a standalone RMSNorm kernel with tile\-local epilogue work around the two GEMMs, plus a lightweight auxiliary reduction\.
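The decomposition above can be checked against the canonical computation with a short NumPy sketch. It is an illustration under assumptions, not the CODA kernel: here `partialRMS` is taken to be a per-tile sum of squares and `reduce` to be the sum of partials divided by the hidden size `d`; the shapes and `tile_n` are arbitrary, and everything runs in FP64 on the CPU.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d_in, d, d_out, tile_n, eps = 8, 16, 32, 24, 8, 1e-6
x = rng.standard_normal((tokens, d_in))
w0 = rng.standard_normal((d_in, d))
w1 = rng.standard_normal((d, d_out))
z = rng.standard_normal((tokens, d))      # residual stream
gamma = rng.standard_normal(d)            # RMSNorm weight

# Canonical path: residual add, standalone RMSNorm, then the second GEMM.
h1_ref = x @ w0 + z
r_ref = 1.0 / np.sqrt(np.mean(h1_ref**2, axis=1, keepdims=True) + eps)
y_ref = (r_ref * h1_ref * gamma) @ w1

# Reparameterized path: GEMM 1 with a tile-local epilogue.
h2 = np.empty((tokens, d))
partials = np.empty((tokens, d // tile_n))
for t, j in enumerate(range(0, d, tile_n)):
    cols = slice(j, j + tile_n)
    h1_tile = x @ w0[:, cols] + z[:, cols]       # GEMM 1 tile + residual
    h2[:, cols] = h1_tile * gamma[cols]          # RMSNorm weight, r not yet applied
    partials[:, t] = np.sum(h1_tile**2, axis=1)  # assumed form of partialRMS

# Auxiliary reduction over a few partials per row (not the full activation).
r = 1.0 / np.sqrt(partials.sum(axis=1, keepdims=True) / d + eps)

# GEMM 2 with the delayed row-wise scale applied in its epilogue.
y = r * (h2 @ w1)
assert np.allclose(y, y_ref)
```

The auxiliary step reads `d // tile_n` values per row instead of `d`, which is the source of the reduced memory traffic described above.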
In [Figure 5](https://arxiv.org/html/2605.19269#S3.F5), we benchmark this reparameterization against existing implementations on LLaMA\-style configurations with a batch of 16K tokens\. We vary the hidden dimension across representative model scales, with $d \in \{2048, 4096, 8192\}$ corresponding roughly to 1B, 7B, and 70B models, respectively\. Our GEMM\-Epilogue kernel is generated by an LLM provided with the above abstractions \(explained in more detail in [Section 3\.3\.1](https://arxiv.org/html/2605.19269#S3.SS3.SSS1)\)\.
Figure 6: Relative error\.
##### Numerics\.
The reparameterization changes where the RMSNorm scale is applied: the row\-wise factor $r$ is delayed from before the second GEMM to the second GEMM epilogue\. We compare BF16 GEMM\-RMSNorm\-GEMM outputs against an FP32 reference on Llama\-3 8B layers\. We report the errors of CODA and QuACK, on which our GEMM template is based, normalized by the error of the standard PyTorch path\. [Figure 6](https://arxiv.org/html/2605.19269#S3.F6) suggests that a more accurate GEMM mainloop can reduce numerical error, and that CODA’s reparameterized epilogue can reduce it further\.
#### 3\.2\.2 GEMM with Pairwise Activations
A second common pattern in Transformers is a GEMM followed by a*pairwise*activation\. Unlike an elementwise activation, which transforms each feature independently, a pairwise activation consumes two adjacent feature values and produces one or two outputs\.
$$\bm{h} = \bm{x}\bm{W}, \qquad \bm{h}_a[i,j],\, \bm{h}_b[i,j] = \operatorname{split}\left(\bm{h}[i,j]\right), \qquad \bm{y}[i,j] = \bm{f}[i,j]\left(\bm{h}_a[i,j], \bm{h}_b[i,j]\right).$$
Figure 7: Pairwise activations operate on local feature pairs in the GEMM epilogue\.
This form captures several operations in Transformer blocks:
- RoPE rotates each feature pair and returns two outputs;
- SwiGLU combines the gate and value streams into one output;
- the SwiGLU backward pass maps one incoming gradient into gradients for both paired inputs\.
Pairwise activations couple neighboring feature lanes and may change the feature dimension\. A naive implementation materializes the GEMM output, splits it into pairs, and applies the activation in a separate kernel\. This adds memory traffic and sometimes materializes an expanded intermediate, as in SwiGLU\.
Instead, we arrange paired features to be adjacent along the output\-feature dimension\. This matches the Hopper Tensor Core accumulator layout exposed to the epilogue, where each thread holds a small tuple of adjacent output values in registers before they are stored\. The epilogue can therefore apply $f$ directly to each pair with register\-level computation\.
This removes the standalone activation kernel and avoids materializing the paired intermediate in global memory\. The same idea applies to dimension\-preserving operations such as RoPE, dimension\-reducing operations such as SwiGLU, and dimension\-expanding operations in the backward pass, as long as the pairing is reflected in the GEMM output layout\. See [Figure 8](https://arxiv.org/html/2605.19269#S3.F8) for performance benchmarks\.
Figure 8: Kernel\-level speedups for representative GEMM\-plus\-epilogue primitives across $MNK$ sizes\. RoPE uses an output dimension of $3N$ for QKV projections, and cross\-entropy uses a $32\mathrm{K}$ vocabulary\. Speedups are relative to cuBLAS with torch\.compile\.
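The pairing idea for a dimension-reducing activation can be sketched in NumPy. This is an illustrative reference only: the column interleaving `[gate_0, up_0, gate_1, up_1, ...]`, the names `w_gate`, `w_up`, and `w_pair`, and the convention `silu(gate) * up` are assumptions chosen for the example, not details confirmed by the text.

```python
import numpy as np

def silu(a):
    # SiLU / swish activation: a * sigmoid(a).
    return a / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
tokens, d, d_ff = 4, 8, 6
x = rng.standard_normal((tokens, d))
w_gate = rng.standard_normal((d, d_ff))
w_up = rng.standard_normal((d, d_ff))

# Unfused reference: two projections, then a separate activation step.
y_ref = silu(x @ w_gate) * (x @ w_up)

# Interleave weight columns so each gate/up pair is adjacent in the output.
w_pair = np.empty((d, 2 * d_ff))
w_pair[:, 0::2] = w_gate
w_pair[:, 1::2] = w_up

# One GEMM produces the paired output; the "epilogue" consumes adjacent
# feature pairs and writes a single, half-width result.
h = x @ w_pair
y = silu(h[:, 0::2]) * h[:, 1::2]
assert np.allclose(y, y_ref)
```

Because each pair sits in neighboring output columns, the `2 * d_ff`-wide intermediate `h` never needs to leave the tile in a fused kernel; only the `d_ff`-wide result is stored.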
#### 3\.2\.3 GEMM with Cross\-Entropy Loss
Cross\-entropy loss can also be expressed as a GEMM with epilogue\-side reductions, as shown by Cut Cross\-Entropy \[[21](https://arxiv.org/html/2605.19269#bib.bib21)\]\. Let $\bm{h}_i = \bm{x}_i \bm{W}_{\mathrm{lm}}$ be the logits for token $i$, and let $\bm{y}_i$ be its target label\. The per\-token loss is
$$\ell_i = -\bm{h}_{i,\bm{y}_i} + \log\sum_{k} \exp(\bm{h}_{i,k}).$$
Thus, the loss decomposes into an indexed logit and a row\-wise log\-sum\-exp over vocabulary entries\.
Both terms fit the GEMM\-plus\-epilogue pattern\. The indexed logit can be selected from the GEMM output tile using the target label, while the LSE can be accumulated as tile\-local maximum and sum\-exp statistics\. A small auxiliary reduction then combines these statistics across tiles, avoiding a standalone memory\-bound softmax over the full logits\. \(We use a separate final reduction rather than atomics, and materialize logits to simplify the backward pass\.\) See [Figure 8](https://arxiv.org/html/2605.19269#S3.F8) for performance benchmarks\.
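The tile-local statistics and their final combination can be sketched as follows. This NumPy reference is an assumption-laden illustration: the vocabulary is split into equal tiles of width `tile_n`, each tile emits a row-wise max and a sum-exp relative to that max, and a separate reduction rescales and combines them. Variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d, vocab, tile_n = 4, 8, 32, 8
x = rng.standard_normal((tokens, d))
w_lm = rng.standard_normal((d, vocab))
targets = rng.integers(0, vocab, size=tokens)

# Reference: full logits followed by a standalone log-sum-exp.
logits = x @ w_lm
m = logits.max(axis=1)
lse_ref = m + np.log(np.exp(logits - m[:, None]).sum(axis=1))
loss_ref = lse_ref - logits[np.arange(tokens), targets]

n_tiles = vocab // tile_n
tile_max = np.empty((tokens, n_tiles))
tile_sumexp = np.empty((tokens, n_tiles))
target_logit = np.zeros(tokens)
for t in range(n_tiles):
    lo = t * tile_n
    h_tile = x @ w_lm[:, lo:lo + tile_n]               # GEMM output tile
    tile_max[:, t] = h_tile.max(axis=1)                # tile-local max
    tile_sumexp[:, t] = np.exp(h_tile - tile_max[:, [t]]).sum(axis=1)
    # Select the target logit for rows whose label falls inside this tile.
    rows = np.nonzero((targets >= lo) & (targets < lo + tile_n))[0]
    target_logit[rows] = h_tile[rows, targets[rows] - lo]

# Auxiliary reduction: rescale each tile's sum-exp to the global row max.
row_max = tile_max.max(axis=1)
scaled = tile_sumexp * np.exp(tile_max - row_max[:, None])
lse = row_max + np.log(scaled.sum(axis=1))
loss = lse - target_logit
assert np.allclose(loss, loss_ref)
```

Rescaling by `exp(tile_max - row_max)` keeps every exponent non-positive, so the combination stays numerically stable for the same reason as a standard max-subtracted softmax.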
#### 3\.2\.4 Backward Pass
The preceding sections show that much of the Transformer forward pass can be reparameterized as GEMMs with epilogues, plus lightweight auxiliary reductions\. We now show that the backward pass preserves the same structure\.
##### GEMM with elementwise epilogue\.
Consider two GEMMs separated by an elementwise epilogue:
$$\bm{h} = \bm{x}\bm{W}_0, \qquad \bm{h}' = f(\bm{h}), \qquad \bm{y} = \bm{h}'\bm{W}_1,$$
where $f$ is applied elementwise\. Given an upstream gradient $\nabla_{\bm{y}}\mathcal{L}$, reverse\-mode differentiation gives
$$\nabla_{\bm{h}'}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top}, \qquad \nabla_{\bm{h}}\mathcal{L} = \nabla_{\bm{h}'}\mathcal{L} \odot f'(\bm{h}), \qquad \nabla_{\bm{x}}\mathcal{L} = \nabla_{\bm{h}}\mathcal{L}\,\bm{W}_0^{\top}.$$
Thus, the backward computation has the same structure as the forward computation: GEMM, local transformation, GEMM\. The only difference is the direction of fusion\. In the forward pass, $f$ is fused into the epilogue of the *preceding* GEMM that produces $\bm{h}$; in the backward pass, multiplication by $f'(\bm{h})$ is fused into the epilogue of the *following* GEMM that produces $\nabla_{\bm{h}'}\mathcal{L}$ \([Figure 9](https://arxiv.org/html/2605.19269#S3.F9)\)\.
###### Theorem 1\.
Consider a sequence of GEMM\-with\-epilogue blocks followed by a final GEMM:
$$
\begin{aligned}
\bm{h}_{\ell} &= \bm{x}_{\ell-1}\bm{W}_{\ell}, \qquad \bm{x}_{\ell}[i,j] = \bm{f}_{\ell}[i,j]\!\left(\bm{h}_{\ell}[i,j]\right), \qquad \ell = 1, \ldots, L-1,\\
\bm{h}_{L} &= \bm{x}_{L-1}\bm{W}_{L}.
\end{aligned}
$$
Assume each tile function $\bm{f}_{\ell}[i,j]$ acts only on its corresponding GEMM output tile\. Then the activation gradients can be computed with the same GEMM\-with\-epilogue structure:
$$
\begin{aligned}
\nabla_{\bm{x}_{\ell-1}}\mathcal{L} &= \nabla_{\bm{h}_{\ell}}\mathcal{L}\,\bm{W}_{\ell}^{\top}, \qquad \nabla_{\bm{h}_{\ell-1}}\mathcal{L}[i,j] = \bm{g}_{\ell-1}[i,j]\!\left(\nabla_{\bm{x}_{\ell-1}}\mathcal{L}[i,j],\; \bm{h}_{\ell-1}[i,j]\right), \qquad \ell = L, \ldots, 2,\\
\nabla_{\bm{x}_{0}}\mathcal{L} &= \nabla_{\bm{h}_{1}}\mathcal{L}\,\bm{W}_{1}^{\top}.
\end{aligned}
$$
Here $\bm{g}_{\ell}[i,j]$ is the tile\-local backward rule for $\bm{f}_{\ell}[i,j]$: it maps the gradient of the epilogue output tile $\bm{x}_{\ell}[i,j]$ to the gradient of the epilogue input tile $\bm{h}_{\ell}[i,j]$\. The weight gradients are GEMMs:
$$\nabla_{\bm{W}_{\ell}}\mathcal{L} = \bm{x}_{\ell-1}^{\top}\nabla_{\bm{h}_{\ell}}\mathcal{L}, \qquad \ell = 1, \ldots, L.$$
Thus, tile\-local epilogues in the forward pass induce tile\-local epilogues in the backward pass, while the surrounding linear maps remain GEMMs\.
###### Proof sketch\.
For the GEMM $\bm{h}_{\ell} = \bm{x}_{\ell-1}\bm{W}_{\ell}$, reverse\-mode differentiation gives
$$\nabla_{\bm{x}_{\ell-1}}\mathcal{L} = \nabla_{\bm{h}_{\ell}}\mathcal{L}\,\bm{W}_{\ell}^{\top}, \qquad \nabla_{\bm{W}_{\ell}}\mathcal{L} = \bm{x}_{\ell-1}^{\top}\nabla_{\bm{h}_{\ell}}\mathcal{L},$$
which are both GEMMs\. For the epilogue $\bm{x}_{\ell}[i,j] = \bm{f}_{\ell}[i,j](\bm{h}_{\ell}[i,j])$, define the local backward rule by
$$\nabla_{\bm{h}_{\ell}}\mathcal{L}[i,j] = \bm{g}_{\ell}[i,j]\!\left(\nabla_{\bm{x}_{\ell}}\mathcal{L}[i,j],\; \bm{h}_{\ell}[i,j]\right).$$
Since each $\bm{f}_{\ell}[i,j]$ depends only on its own tile, the corresponding $\bm{g}_{\ell}[i,j]$ also acts only on that tile\. The backward pass therefore introduces no new cross\-tile communication and preserves the GEMM\-with\-epilogue structure\. ∎
Figure 9:Forward and backward fusion for GEMM–epilogue blocks\. Forward epilogues attach to the GEMM that produces their input, while backward epilogues attach to the GEMM that produces the gradient with respect to their output\.
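The elementwise case can be checked numerically with a small NumPy sketch. It assumes `tanh` as a stand-in for the elementwise epilogue `f` and a hypothetical scalar loss `sum(y * grad_y)`, chosen so that the upstream gradient is exactly `grad_y`; the backward GEMMs with their fused local multiply are compared against central finite differences.

```python
import numpy as np

def f(h):
    return np.tanh(h)              # stand-in elementwise epilogue

def f_prime(h):
    return 1.0 - np.tanh(h) ** 2   # its derivative

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6))
w0 = rng.standard_normal((6, 8))
w1 = rng.standard_normal((8, 5))
grad_y = rng.standard_normal((4, 5))   # upstream gradient

def loss(x_in):
    # Scalar loss whose gradient with respect to y equals grad_y.
    return np.sum((f(x_in @ w0) @ w1) * grad_y)

h = x @ w0                              # saved by the forward epilogue
# Backward GEMM producing grad wrt h', with f'(h) applied in its epilogue.
grad_h = (grad_y @ w1.T) * f_prime(h)
grad_x = grad_h @ w0.T                  # next backward GEMM
grad_w0 = x.T @ grad_h                  # weight gradients are plain GEMMs
grad_w1 = f(h).T @ grad_y

# Central finite differences on every entry of x.
step = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    delta = np.zeros_like(x)
    delta[idx] = step
    numeric[idx] = (loss(x + delta) - loss(x - delta)) / (2 * step)
assert np.allclose(grad_x, numeric, atol=1e-5)
```

The only extra state the backward pass needs is `h`, which the forward epilogue can store with a tile store, matching the "saved activations" use of the tile load/store primitive.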
##### RMSNorm backward\.
RMSNorm is the main case where the backward pass is not purely tile\-local\. Its backward rule introduces two reductions: a row\-wise statistic needed for the input gradient, and a feature\-wise reduction across rows for the RMSNorm weight gradient\. A direct implementation computes both in a standalone RMSNorm backward kernel, requiring additional reads of activation\-sized tensors\. However, the row\-wise statistic can be moved to a neighboring GEMM boundary\. Consider
$$\bm{h}_0 = \bm{x}\bm{W}_0, \qquad \bm{h}_1 = f(\bm{h}_0), \qquad \bm{h}_2 = \operatorname{RMSNorm}(\bm{h}_1, \bm{\gamma}), \qquad \bm{y} = \bm{h}_2\bm{W}_1.$$
RMSNorm backward requires the row\-wise inner product
$$\bm{s} = \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \bm{h}_2\right).$$
Using $\nabla_{\bm{h}_2}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top}$ and $\bm{y} = \bm{h}_2\bm{W}_1$, this statistic can be equivalently written as
$$\bm{s} = \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{y}}\mathcal{L} \odot \bm{y}\right).$$
This identity changes where the statistic is computed\. Instead of launching a standalone RMSNorm backward kernel to read $\bm{h}_2$ and $\nabla_{\bm{h}_2}\mathcal{L}$, we accumulate the same row\-wise quantity at a boundary where $\bm{y}$ and $\nabla_{\bm{y}}\mathcal{L}$ are already available, thereby exposing the computation to epilogue fusion\.
In consecutive Transformer patterns, each GEMM epilogue can therefore accumulate the row\-wise statistic needed by the preceding RMSNorm backward\. The RMSNorm weight gradient is handled similarly by emitting tile partials for the reduction across rows\. Overall, RMSNorm backward becomes GEMM–epilogue kernels plus lightweight auxiliary reductions over tile partials\. We give the full derivation and kernel organization in [Section A\.2](https://arxiv.org/html/2605.19269#A1.SS2), with benchmarks in [Figure 11](https://arxiv.org/html/2605.19269#S4.F11)\.
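The identity for the row-wise statistic can be verified with a few lines of NumPy. This is a sanity check under arbitrary assumed shapes, not the kernel organization from the appendix; `h2`, `w1`, and `grad_y` are random stand-ins for the normalized activations, the next projection, and the upstream gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d, d_out = 4, 16, 12
h2 = rng.standard_normal((tokens, d))        # RMSNorm output
w1 = rng.standard_normal((d, d_out))         # following projection
grad_y = rng.standard_normal((tokens, d_out))

y = h2 @ w1                 # forward GEMM output
grad_h2 = grad_y @ w1.T     # backward GEMM output

# Statistic as a standalone RMSNorm backward kernel would compute it.
s_direct = np.sum(grad_h2 * h2, axis=1) / d
# Same statistic computed at the boundary where y and grad_y already live.
s_moved = np.sum(grad_y * y, axis=1) / d
assert np.allclose(s_direct, s_moved)
```

Both expressions equal the trace-like sum over `grad_y[i, k] * w1[j, k] * h2[i, j]`, just grouped differently, which is why the reduction can move across the GEMM.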
### 3\.3 Implementation
We implement CODA on top of CuTeDSL, which provides Python\-level kernel authoring while retaining low\-level control over details such as layouts and memory movement\.
*Data movement\.* Vector loads handle small broadcast operands such as RMSNorm weights\. These values are staged once in shared memory and reused across subtiles\. Tile loads handle larger operands such as residual activations, and use Tensor Memory Accelerator \(TMA\) transfers between global memory and shared memory, allowing data movement to overlap with epilogue computation\. Stores follow a similar tile\-granular path for transformed outputs, saved intermediates, and reduction partials\.
*Local computation\.* Pairwise maps act on neighboring features, covering dimension\-preserving operations such as RoPE, dimension\-reducing operations such as SwiGLU, and dimension\-expanding operations in the backward pass\. When a map changes the feature dimension, the epilogue performs the corresponding local layout adjustment, such as compacting dimension\-reducing outputs or packing pairs of 16\-bit values for dimension\-expanding outputs\.
*Reductions\.* Tile\-wise reductions follow the ownership of GEMM output fragments\. Row\-wise reductions are accumulated by the warp that owns the row\. Column\-wise reductions may span multiple warps, so each warp first produces a partial result, and these partials are combined through shared memory\.
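As an illustration of a dimension-preserving pairwise map from the local-computation paragraph, the sketch below applies a RoPE-style rotation to adjacent feature pairs of a GEMM output. The interleaved pairing `(2k, 2k+1)`, the base of 10000, and all variable names are assumptions for this example (some implementations pair feature `k` with `k + d/2` instead); the result is checked against complex multiplication.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d_in, d_head = 4, 8, 6
x = rng.standard_normal((tokens, d_in))
w = rng.standard_normal((d_in, d_head))

# Per-position, per-pair rotation angles (assumed RoPE-style schedule).
pos = np.arange(tokens)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, d_head, 2) / d_head))
theta = pos[:, None] * inv_freq[None, :]     # shape (tokens, d_head // 2)
cos, sin = np.cos(theta), np.sin(theta)

h = x @ w                                    # GEMM output
a, b = h[:, 0::2], h[:, 1::2]                # adjacent feature pairs

# Pairwise "epilogue": two inputs in, two outputs out, same feature width.
y = np.empty_like(h)
y[:, 0::2] = a * cos - b * sin
y[:, 1::2] = a * sin + b * cos

# Reference: the same rotation written as complex multiplication.
ref = (a + 1j * b) * np.exp(1j * theta)
assert np.allclose(y[:, 0::2], ref.real)
assert np.allclose(y[:, 1::2], ref.imag)
```

Each output pair depends only on its own input pair and a broadcast angle, so the map needs no cross-lane communication beyond the two adjacent values.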
#### 3.3.1 LLM-Oriented Authoring
CODA is designed both for LLM workloads and for LLM-assisted authoring. Rather than asking a model to synthesize arbitrary CUDA or discover a hardware schedule from scratch, CODA exposes a constrained space of epilogue programs around expert-designed GEMM mainloops. Its primitives already encode efficient implementation strategies, so LLM-based authoring becomes a problem of composing vector loads, tile loads, pairwise maps, reductions, and stores for a given Transformer computation. This lightweight use of LLMs is complementary to prior work on kernel generation, which often relies on orchestration, search, execution feedback, or post-training \[[12](https://arxiv.org/html/2605.19269#bib.bib12), [16](https://arxiv.org/html/2605.19269#bib.bib16), [3](https://arxiv.org/html/2605.19269#bib.bib3), [10](https://arxiv.org/html/2605.19269#bib.bib10), [24](https://arxiv.org/html/2605.19269#bib.bib24)\].
Because CuTeDSL is relatively new, current models have limited exposure to its idioms\. We therefore provide curated demonstrations for each abstraction\. In practice, the repository itself acts as a growing demonstration set, with new kernels being written by adapting and composing existing examples\.
##### Compositions\.
Transformer kernels often combine several epilogue operations, such as residual addition and RMSNorm scaling. Monolithic fused epilogues lead to large, repetitive implementations that are difficult to place in context. CODA instead represents each epilogue as a composition of reusable primitives: the LLM specifies the local epilogue program, while the library supplies the fixed GEMM mainloop and implementation pattern for each primitive. New fused kernels are therefore assembled from reusable building blocks instead of being rewritten from scratch.
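To make the idea of composition concrete, the following toy sketch chains two primitives, a residual addition and a broadcast weight scaling, into one epilogue program applied to a small tile. It uses plain Python lists rather than CuTeDSL, and the names `tile_add`, `row_vec_mul`, and `compose` are hypothetical and do not correspond to CODA's actual primitive API.

```python
from functools import reduce

def tile_add(residual):
    """Primitive: add a same-shaped residual tile (a tile-load operand)."""
    return lambda tile: [
        [a + b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(tile, residual)
    ]

def row_vec_mul(weight):
    """Primitive: scale each column by a broadcast vector (a vector-load operand)."""
    return lambda tile: [[a * w for a, w in zip(row, weight)] for row in tile]

def compose(*primitives):
    """Chain primitives left to right into a single epilogue program."""
    return lambda tile: reduce(lambda t, prim: prim(t), primitives, tile)

# Residual add followed by RMSNorm-style weight scaling on a 2x2 tile.
epilogue = compose(tile_add([[1.0, 1.0], [1.0, 1.0]]), row_vec_mul([2.0, 0.5]))
out = epilogue([[1.0, 2.0], [3.0, 4.0]])
```

Swapping or adding a primitive changes only the argument list of `compose`, which is the property that keeps each fused kernel short enough to serve as an in-context example.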
Figure 10: Kernel-level speedups on reparameterized Transformer kernels relative to cuBLAS with `torch.compile`. Raw GEMM baselines using PyTorch/cuBLAS and QuACK are included as reference ceilings, since they execute only the matrix multiplication and no epilogue work.
## 4 Experiments
After the reparameterizations in [Section 3.2](https://arxiv.org/html/2605.19269#S3.SS2), we obtain a compact benchmark suite of GEMM-plus-epilogue kernels spanning the Transformer++ forward and backward pass ([Section C.1](https://arxiv.org/html/2605.19269#A3.SS1)). The suite covers nearly all computation outside attention, embeddings, auxiliary reductions, and lightweight glue operations. We evaluate two implementations. CODA (LLM) uses Claude Code to generate most kernels from a written specification, curated examples, and a running log of implementation tips, with lightweight human supervision. CODA (human) is written by human programmers using the same high-level reparameterizations, but without access to the exact CODA primitive set.
We compare against cuBLAS with `torch.compile`, as well as optimized LLM kernel libraries including Liger Kernel \[[6](https://arxiv.org/html/2605.19269#bib.bib6)\] and FlashInfer \[[23](https://arxiv.org/html/2605.19269#bib.bib23)\]. Because our reparameterized kernels do not always have one-to-one counterparts in existing libraries, we compose the closest available optimized primitives and fall back to PyTorch operators as needed. We apply `torch.compile` to each method when compatible. Additional setup details are given in [Section C.2](https://arxiv.org/html/2605.19269#A3.SS2).
##### Kernel Benchmarks\.
We first evaluate CODA at the individual-kernel level. Unless otherwise noted, we use square GEMM shapes with $M = N = K \in \{4096, 8192\}$. For cross-entropy kernels, we set the vocabulary dimension to $32768$. For RoPE kernels, we use $N_{\mathrm{rope}} = 3N$ to account for QKV-style projections and use precomputed $\cos$ and $\sin$ tables (footnote 3: for CODA, we additionally pre-broadcast and extend these tables across batch, head, and QKV dimensions to avoid in-kernel branching, at the cost of additional input traffic). For kernels that emit partial reductions, we benchmark only the fused GEMM kernel with reduction tile size $128$; auxiliary reductions are included in the block-level benchmarks below. We benchmark functions using Triton's `do_bench` and show means and standard deviations across $30$ runs.
We evaluate two groups of kernels from [Section C.1](https://arxiv.org/html/2605.19269#A3.SS1). The first group consists of standard Transformer-style kernels, such as GEMM with RoPE, SwiGLU, or cross-entropy epilogues. These kernels have close counterparts in existing libraries, so we compare against cuBLAS with `torch.compile`, Liger Kernel, and FlashInfer when applicable. Results are shown in [Figure 8](https://arxiv.org/html/2605.19269#S3.F8).
The second group consists of reparameterized Transformer forward and backward kernels, which generally do not have one-to-one equivalents in existing libraries. For these kernels, we primarily compare against cuBLAS with `torch.compile`, and additionally report raw GEMM from PyTorch/cuBLAS and QuACK (footnote 4: [https://github.com/Dao-AILab/quack](https://github.com/Dao-AILab/quack)). These raw GEMMs omit epilogue work and therefore serve as reference ceilings for the attainable throughput. Results are shown in [Figure 10](https://arxiv.org/html/2605.19269#S3.F10).
Figure 11: Block-level speedups for reparameterized Transformer kernel sequences, including auxiliary reductions and lightweight glue operations. Here, a layer denotes two consecutive GEMM-Residual-RMSNorm-GEMM blocks with the SwiGLU and RoPE activations, respectively.
##### Block Benchmarks\.
We next benchmark kernel sequences corresponding to Transformer sublayers and full layers, which we call *blocks*. We use hidden sizes in $\{2048, 4096, 8192\}$, roughly matching 1B, 7B, and 70B model scales, with FFN expansion rate $8/3$ rounded to multiples of $256$ and vocabulary size $32768$. Unlike isolated kernel benchmarks, these measurements include auxiliary reductions and lightweight glue operations.
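For concreteness, the FFN widths implied by this configuration can be computed as below. The text does not state the rounding direction, so this sketch assumes rounding up to the next multiple of 256 (a common convention); `ffn_dim` and `multiple_of` are hypothetical names, and rounding to the nearest multiple would give different values for some hidden sizes.

```python
def ffn_dim(hidden, multiple_of=256):
    """Round 8/3 * hidden UP to the next multiple of `multiple_of` (assumed direction)."""
    raw = -(-8 * hidden // 3)  # ceil(8 * hidden / 3) using integer arithmetic
    return -(-raw // multiple_of) * multiple_of

sizes = {hidden: ffn_dim(hidden) for hidden in (2048, 4096, 8192)}
```

Under this assumption the three hidden sizes map to FFN widths of 5632, 11008, and 22016.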
For the forward pass, we compare each reparameterized sequence against the closest available sequence of optimized operators. For the backward pass, the reparameterization changes the dependency structure: each sublayer emits partial statistics needed by the preceding RMSNorm backward. Individual backward sublayers therefore do not have direct PyTorch counterparts, so we report backward results at the full-layer level. In CODA, a layer consists of two consecutive GEMM-Residual-RMSNorm-GEMM blocks covering the SwiGLU and RoPE paths. Results are shown in [Figure 11](https://arxiv.org/html/2605.19269#S4.F11).
## 5 Conclusion and Limitations
CODA reparameterizes much of Transformer computation as GEMM epilogues, reducing memory-bound overhead while preserving GEMM efficiency. Its constrained abstraction supports high-performance kernels authored by both humans and LLMs.
##### Limitations\.
Our reparameterizations target a common Transformer architecture; extending them to broader model families is future work. CODA currently focuses on single-GPU kernels and does not yet address distributed execution. Finally, while reparameterization improves efficiency, it can obscure module boundaries and algorithmic semantics, making integration with framework-level abstractions more challenging.
## Acknowledgment
We thank Beshr Islam Bouli, Kaiming Cheng, Xinle Cheng, Ryan Chin, Tarushii Goel, Wentao Guo, Lucas Torroba Hennigen, Alicia Li, Mayank Mishra, Jyothish Pari, Caiming Xiong, Nicholas Yap, Tianyuan Zhang, and Adam Zweiger for helpful discussion\. We gratefully acknowledge the support of the Schmidt Sciences AI2050 fellowship, the Google ML and Systems Junior Faculty Awards, the Google Research Scholar program, and the National Science Foundation \(Award \#2441872\)\.
## References
- Ansel et al. [2024] J. Ansel, E. Yang, H. He, N. Gimelshein, A. Jain, M. Voznesensky, B. Bao, P. Bell, D. Berard, E. Burovski, et al. Pytorch 2: Faster machine learning through dynamic python bytecode transformation and graph compilation. In *Proceedings of the 29th ACM international conference on architectural support for programming languages and operating systems, volume 2*, pages 929–947, 2024.
- Chen et al. [2018] T. Chen, T. Moreau, Z. Jiang, L. Zheng, E. Yan, H. Shen, M. Cowan, L. Wang, Y. Hu, L. Ceze, et al. TVM: An automated End-to-End optimizing compiler for deep learning. In *13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)*, pages 578–594, 2018.
- Chen et al. [2025] W. Chen, J. Zhu, Q. Fan, Y. Ma, and A. Zou. Cuda-llm: Llms can write efficient cuda kernels. *arXiv preprint arXiv:2506.09092*, 2025.
- Chen et al. [2024] Z. Chen, A. Kerr, R. Cai, J. Kosaian, H. Wu, Y. Ding, and Y. Xie. Evt: Accelerating deep learning training with epilogue visitor tree. In *Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 3*, pages 301–316, 2024.
- Grattafiori et al. [2024] A. Grattafiori, A. Dubey, A. Jauhri, A. Pandey, A. Kadian, A. Al-Dahle, A. Letman, A. Mathur, A. Schelten, A. Vaughan, et al. The llama 3 herd of models. *arXiv preprint arXiv:2407.21783*, 2024.
- Hsu et al. [2024] P.-L. Hsu, Y. Dai, V. Kothapalli, Q. Song, S. Tang, S. Zhu, S. Shimizu, S. Sahni, H. Ning, and Y. Chen. Liger kernel: Efficient triton kernels for llm training. *arXiv preprint arXiv:2410.10989*, 2024.
- Ivanov et al. [2021] A. Ivanov, N. Dryden, T. Ben-Nun, S. Li, and T. Hoefler. Data movement is all you need: A case study on optimizing transformers. *Proceedings of Machine Learning and Systems*, 3:711–732, 2021.
- Jia et al. [2019] Z. Jia, O. Padon, J. Thomas, T. Warszawski, M. Zaharia, and A. Aiken. Taso: optimizing deep learning computation with automatic generation of graph substitutions. In *Proceedings of the 27th ACM Symposium on Operating Systems Principles*, pages 47–62, 2019.
- Kwon et al. [2023] W. Kwon, Z. Li, S. Zhuang, Y. Sheng, L. Zheng, C. H. Yu, J. Gonzalez, H. Zhang, and I. Stoica. Efficient memory management for large language model serving with pagedattention. In *Proceedings of the 29th symposium on operating systems principles*, pages 611–626, 2023.
- Lange et al. [2025] R. T. Lange, Q. Sun, A. Prasad, M. Faldor, Y. Tang, and D. Ha. Towards robust agentic cuda kernel benchmarking, verification, and optimization. *arXiv preprint arXiv:2509.14279*, 2025.
- Liang et al. [2024] W. Liang, T. Liu, L. Wright, W. Constable, A. Gu, C.-C. Huang, I. Zhang, W. Feng, H. Huang, J. Wang, et al. Torchtitan: One-stop pytorch native solution for production ready llm pre-training. *arXiv preprint arXiv:2410.06511*, 2024.
- Ouyang et al. [2025] A. Ouyang, S. Guo, S. Arora, A. L. Zhang, W. Hu, C. Ré, and A. Mirhoseini. Kernelbench: Can llms write efficient gpu kernels? *arXiv preprint arXiv:2502.10517*, 2025.
- Spector et al. [2025] B. Spector, J. Juravsky, S. Sul, O. Dugan, D. Lim, D. Fu, S. Arora, and C. Ré. Look ma, no bubbles! designing a low-latency megakernel for llama-1b, 2025.
- Spector et al. [2024] B. F. Spector, S. Arora, A. Singhal, D. Y. Fu, and C. Ré. Thunderkittens: Simple, fast, and adorable ai kernels. *arXiv preprint arXiv:2410.20399*, 2024.
- Su et al. [2024] J. Su, M. Ahmed, Y. Lu, S. Pan, W. Bo, and Y. Liu. Roformer: Enhanced transformer with rotary position embedding. *Neurocomputing*, 568:127063, 2024.
- Su et al. [2025] S. Su, X. Sun, X. Li, A. Wang, J. Li, and C. Shum. Cuda-l2: Surpassing cublas performance for matrix multiplication through reinforcement learning. *arXiv preprint arXiv:2512.02551*, 2025.
- Sul et al. [2025] S. H. Sul, S. Arora, B. F. Spector, and C. Ré. Parallelkittens: Systematic and practical simplification of multi-gpu ai kernels. *arXiv preprint arXiv:2511.13940*, 2025.
- Thakkar et al. [2023] V. Thakkar, P. Ramani, C. Cecka, A. Shivam, H. Lu, E. Yan, J. Kosaian, M. Hoemmen, H. Wu, A. Kerr, M. Nicely, D. Merrill, D. Blasig, A. Atluri, F. Qiao, P. Majcher, P. Springer, M. Hohnerbach, J. Wang, and M. Gupta. CUTLASS, Jan. 2023. URL [https://github.com/NVIDIA/cutlass](https://github.com/NVIDIA/cutlass).
- Tillet et al. [2019] P. Tillet, H.-T. Kung, and D. Cox. Triton: an intermediate language and compiler for tiled neural network computations. In *Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages*, pages 10–19, 2019.
- Wang et al. [2025] L. Wang, Y. Cheng, Y. Shi, Z. Tang, Z. Mo, W. Xie, L. Ma, Y. Xia, J. Xue, F. Yang, et al. Tilelang: A composable tiled programming model for ai systems. *arXiv preprint arXiv:2504.17577*, 2025.
- Wijmans et al. [2025] E. Wijmans, B. Huval, A. Hertzberg, V. Koltun, and P. Krähenbühl. Cut your losses in large-vocabulary language models. In *International Conference on Learning Representations*, 2025.
- Wu et al. [2025] M. Wu, X. Cheng, S. Liu, C. Shi, J. Ji, M. K. Ao, P. Velliengiri, X. Miao, O. Padon, and Z. Jia. Mirage: A Multi-Level superoptimizer for tensor programs. In *19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)*, pages 21–38, 2025.
- Ye et al. [2025] Z. Ye, L. Chen, R. Lai, W. Lin, Y. Zhang, S. Wang, T. Chen, B. Kasikci, V. Grover, A. Krishnamurthy, et al. Flashinfer: Efficient and customizable attention engine for llm inference serving. *Proceedings of Machine Learning and Systems*, 7, 2025.
- Yuksekgonul et al. [2026] M. Yuksekgonul, D. Koceja, X. Li, F. Bianchi, J. McCaleb, X. Wang, J. Kautz, Y. Choi, J. Zou, C. Guestrin, et al. Learning to discover at test time. *arXiv preprint arXiv:2601.16175*, 2026.
- Zheng et al. [2024] L. Zheng, L. Yin, Z. Xie, C. L. Sun, J. Huang, C. H. Yu, S. Cao, C. Kozyrakis, I. Stoica, J. E. Gonzalez, et al. Sglang: Efficient execution of structured language model programs. *Advances in neural information processing systems*, 37:62557–62583, 2024.
## Appendix A Backward Pass
### A.1 Tile-wise Epilogue
Partition the GEMM output $\bm{h}$ into tiles $\bm{h}_{[i,j]}$. A tile-wise epilogue applies an independent transformation to each tile:

$$
\bm{h} = \bm{x}\bm{W}_0, \qquad
\bm{h}' = \begin{bmatrix}
\bm{f}[0,0]\left(\bm{h}[0,0]\right) & \cdots & \bm{f}[0,N]\left(\bm{h}[0,N]\right) \\
\vdots & \ddots & \vdots \\
\bm{f}[M,0]\left(\bm{h}[M,0]\right) & \cdots & \bm{f}[M,N]\left(\bm{h}[M,N]\right)
\end{bmatrix}, \qquad
\bm{y} = \bm{h}'\bm{W}_1.
$$

The backward pass has the same block structure.

$$
\nabla_{\bm{h}'}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top}, \qquad
\nabla_{\bm{h}}\mathcal{L} = \begin{bmatrix}
\bm{g}[0,0]\left(\nabla_{\bm{h}'}\mathcal{L}[0,0]\right) & \cdots & \bm{g}[0,N]\left(\nabla_{\bm{h}'}\mathcal{L}[0,N]\right) \\
\vdots & \ddots & \vdots \\
\bm{g}[M,0]\left(\nabla_{\bm{h}'}\mathcal{L}[M,0]\right) & \cdots & \bm{g}[M,N]\left(\nabla_{\bm{h}'}\mathcal{L}[M,N]\right)
\end{bmatrix}, \qquad
\nabla_{\bm{x}}\mathcal{L} = \nabla_{\bm{h}}\mathcal{L}\,\bm{W}_0^{\top}.
$$

Here each $\bm{g}[i,j]$ is the local backward rule for the corresponding tile:

$$
\bm{g}[i,j](\Delta) = \operatorname{unvec}\!\left(\bm{J}_{[i,j]}^{\top}\operatorname{vec}(\Delta)\right), \qquad
\bm{J}_{[i,j]} = \frac{\partial \operatorname{vec}\left(\bm{f}[i,j]\left(\bm{h}[i,j]\right)\right)}{\partial \operatorname{vec}\left(\bm{h}[i,j]\right)}.
$$

Thus, each gradient tile depends only on the corresponding forward tile and upstream gradient tile. No cross-tile communication is introduced, so the backward operation remains tile-local and can be implemented as a GEMM epilogue. As shown in [Figure 9](https://arxiv.org/html/2605.19269#S3.F9), the only structural change is the direction of fusion: forward epilogues are fused into the GEMM that produces their input, while backward epilogues are fused into the GEMM that produces the gradient with respect to their output.
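A minimal numerical sketch of a tile-local backward rule follows. It assumes a one-by-two tile holding a (gate, up) pair with forward map silu(gate) * up, which is an illustrative choice rather than a statement about CODA's tile shapes, and it checks the hand-written rule $\bm{J}^{\top}\Delta$ against central finite differences. All names are hypothetical.

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))

def f_tile(tile):
    """Forward tile map: (gate, up) -> silu(gate) * up."""
    gate, up = tile
    return silu(gate) * up

def g_tile(tile, delta):
    """Backward tile rule J^T delta, using only this tile's forward inputs."""
    gate, up = tile
    sig = 1.0 / (1.0 + math.exp(-gate))
    dsilu = sig * (1.0 + gate * (1.0 - sig))  # derivative of silu at gate
    return [delta * dsilu * up, delta * silu(gate)]

tile, delta, eps = (0.3, -1.2), 0.7, 1e-6
analytic = g_tile(tile, delta)
numeric = [
    delta * (f_tile((tile[0] + eps, tile[1])) - f_tile((tile[0] - eps, tile[1]))) / (2 * eps),
    delta * (f_tile((tile[0], tile[1] + eps)) - f_tile((tile[0], tile[1] - eps))) / (2 * eps),
]
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
```

Because `g_tile` reads nothing outside its own tile and upstream gradient, the same rule can be applied independently to every tile.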
### A.2 GEMM-RMSNorm-GEMM Backward Pass
We now describe the backward pass for the GEMM–epilogue–RMSNorm–GEMM pattern. RMSNorm is the first case where the backward pass is not purely tile-local. The reason is simple: RMSNorm contains a row-wise normalization factor, so its backward pass needs a row-wise statistic. In addition, the RMSNorm weight $\bm{\gamma}$ is shared across rows, so its gradient requires a reduction across the row dimension. The goal of this section is to show that these are the only non-local pieces. Everything else can still be fused into GEMM epilogues, with the non-local pieces handled by lightweight reductions over tile partials.
Consider the forward computation

$$
\bm{h}_0 = \bm{x}\bm{W}_0, \qquad
\bm{h}_1 = f(\bm{h}_0), \qquad
\bm{h}_2 = \operatorname{RMSNorm}(\bm{h}_1, \bm{\gamma}), \qquad
\bm{y} = \bm{h}_2\bm{W}_1.
$$

Let $\bm{r}$ be the row-wise inverse RMS factor. We write $\overline{\bm{r}} = \bm{r}\mathbf{1}^{\top}$ and $\overline{\bm{\gamma}} = \mathbf{1}\bm{\gamma}^{\top}$ for the broadcasts of $\bm{r}$ and $\bm{\gamma}$ to the shape of $\bm{h}_1$. Then

$$
\bm{h}_2 = \overline{\bm{r}}\odot\bm{h}_1\odot\overline{\bm{\gamma}}.
$$
Given the upstream gradient $\nabla_{\bm{y}}\mathcal{L}$, the first backward operation is the GEMM

$$
\nabla_{\bm{h}_2}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top}.
$$

The RMSNorm backward can be written as

$$
\begin{aligned}
\nabla_{\bm{h}_1}\mathcal{L} &= \overline{\bm{r}}\odot\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\overline{\bm{\gamma}} - \overline{\bm{r}}\odot\bm{h}_1\odot\overline{\bm{s}}\right), \\
\nabla_{\bm{\gamma}}\mathcal{L} &= \operatorname{sum}_{\mathrm{rows}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\bm{h}_1\odot\overline{\bm{r}}\right),
\end{aligned}
$$

where $\overline{\bm{s}} = \bm{s}\mathbf{1}^{\top}$ broadcasts one scalar per row. The row-wise statistic $\bm{s}$ is

$$
\begin{aligned}
\bm{s} &= \frac{1}{d}\odot\bm{r}\odot\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\overline{\bm{\gamma}}\odot\bm{h}_1\right), \\
&= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\overline{\bm{r}}\odot\overline{\bm{\gamma}}\odot\bm{h}_1\right), \\
&= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\bm{h}_2\right),
\end{aligned}
$$

where $d$ is the hidden dimension. This expression identifies the two non-local operations in RMSNorm backward. The statistic $\bm{s}$ is a reduction across columns, producing one scalar per row. The weight gradient $\nabla_{\bm{\gamma}}\mathcal{L}$ is a reduction across rows, producing one scalar per hidden feature.
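The input-gradient formula can be sanity-checked on a single row with finite differences. The sketch below is a pure-Python check, not code from the CODA repository; the function names are hypothetical, and the small `eps` inside the square root is an assumption (the text does not mention one), under which the formula still holds exactly because `r` is defined with the same `eps`.

```python
import math

def rmsnorm_row(h1, gamma, eps=1e-6):
    """Forward RMSNorm on one row; returns (h2, r) with r the inverse RMS."""
    r = 1.0 / math.sqrt(sum(v * v for v in h1) / len(h1) + eps)
    return [r * v * g for v, g in zip(h1, gamma)], r

def rmsnorm_row_backward(h1, gamma, grad_h2):
    """grad_h1 = r * (grad_h2 * gamma - r * h1 * s), with s = mean(grad_h2 * h2)."""
    h2, r = rmsnorm_row(h1, gamma)
    s = sum(g * y for g, y in zip(grad_h2, h2)) / len(h1)
    return [r * (g * w - r * v * s) for g, w, v in zip(grad_h2, gamma, h1)]

def loss(h1, gamma, grad_h2):
    """Scalar whose gradient with respect to h2 is exactly grad_h2."""
    return sum(g * y for g, y in zip(grad_h2, rmsnorm_row(h1, gamma)[0]))

h1 = [0.5, -1.0, 2.0, 0.25]
gamma = [1.0, 0.5, -2.0, 1.5]
grad_h2 = [0.1, -0.3, 0.7, 0.2]

analytic = rmsnorm_row_backward(h1, gamma, grad_h2)
step = 1e-6
numeric = []
for i in range(len(h1)):
    plus = h1[:i] + [h1[i] + step] + h1[i + 1:]
    minus = h1[:i] + [h1[i] - step] + h1[i + 1:]
    numeric.append((loss(plus, gamma, grad_h2) - loss(minus, gamma, grad_h2)) / (2 * step))

max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
```

The only cross-feature quantity the backward rule needs is the scalar `s`, which is what makes the rest of the computation tile-local.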
A standalone RMSNorm backward kernel would compute these reductions by reading activation-sized tensors. The key observation is that the row-wise statistic $\bm{s}$ can be moved to a different boundary. Using $\nabla_{\bm{h}_2}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top}$ and $\bm{y} = \bm{h}_2\bm{W}_1$, we have

$$
\begin{aligned}
\bm{s} &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\bm{h}_2\right), \\
&= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left((\nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top})\odot\bm{h}_2\right) \\
&= \frac{1}{d}\operatorname{diag}\left((\nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top})\;\bm{h}_2^{\top}\right) \\
&= \frac{1}{d}\operatorname{diag}\left(\nabla_{\bm{y}}\mathcal{L}\;(\bm{h}_2\bm{W}_1)^{\top}\right) \\
&= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{y}}\mathcal{L}\odot(\bm{h}_2\bm{W}_1)\right) \\
&= \frac{1}{d}\operatorname{sum}_{\mathrm{output}}\left(\nabla_{\bm{y}}\mathcal{L}\odot\bm{y}\right).
\end{aligned}
$$

Intuitively, the RMSNorm backward needs the inner product between an activation and its gradient along each row. The identity above says that this inner product can be computed either before or after the following GEMM. This lets us compute the statistic at a boundary where $\bm{y}$ and $\nabla_{\bm{y}}\mathcal{L}$ are already available.
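The identity is easy to confirm numerically. The following pure-Python sketch draws small random matrices (the shapes and helper names are arbitrary choices for illustration) and compares the statistic computed before the GEMM, from $\bm{h}_2$ and its gradient, with the statistic computed after it, from $\bm{y}$ and its gradient.

```python
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def row_dot(a, b):
    """Per-row sum of the elementwise product, i.e. sum over columns of a * b."""
    return [sum(x * y for x, y in zip(ra, rb)) for ra, rb in zip(a, b)]

random.seed(0)
m, d, n = 3, 4, 5  # rows, hidden dimension, output dimension
h2 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(m)]
w1 = [[random.gauss(0, 1) for _ in range(n)] for _ in range(d)]
grad_y = [[random.gauss(0, 1) for _ in range(n)] for _ in range(m)]

y = matmul(h2, w1)                       # forward GEMM: y = h2 W1
grad_h2 = matmul(grad_y, transpose(w1))  # backward GEMM: grad_h2 = grad_y W1^T

s_before = [v / d for v in row_dot(grad_h2, h2)]  # statistic at the h2 boundary
s_after = [v / d for v in row_dot(grad_y, y)]     # same statistic at the y boundary
max_gap = max(abs(a - b) for a, b in zip(s_before, s_after))
```

The two vectors agree up to floating-point rounding, even though the second one never touches $\bm{h}_2$ or $\nabla_{\bm{h}_2}\mathcal{L}$.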
This is useful because Transformer layers contain consecutive GEMM–epilogue–RMSNorm–GEMM patterns. During the backward pass of one pattern, the GEMM that produces $\nabla_{\bm{x}}\mathcal{L}$ already has access to both $\bm{x}$ and $\nabla_{\bm{x}}\mathcal{L}$. Since this $\bm{x}$ is the output of the preceding pattern, the epilogue of the current pattern can accumulate the RMSNorm statistic needed by the preceding pattern:

$$
\widehat{\bm{s}}_{\mathrm{prev}} = \operatorname{reduceTile}_{\mathrm{cols}}\left(\bm{x}\odot\nabla_{\bm{x}}\mathcal{L}\right), \qquad
\bm{s}_{\mathrm{prev}} = \frac{1}{d}\operatorname{reduce}\left(\widehat{\bm{s}}_{\mathrm{prev}}\right).
$$

Thus, each pattern computes the row-wise RMSNorm backward statistic required by the pattern before it. The reduction is still present, but it is now a small reduction over tile partials rather than a standalone activation-sized RMSNorm backward kernel.
The RMSNorm weight gradient is handled similarly, except that its reduction is across rows rather than columns. We accumulate tile partials in the RMSNorm backward epilogue:

$$
\widehat{\nabla_{\bm{\gamma}}\mathcal{L}} = \operatorname{reduceTile}_{\mathrm{rows}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\bm{h}_1\odot\overline{\bm{r}}\right), \qquad
\nabla_{\bm{\gamma}}\mathcal{L} = \operatorname{reduce}_{\mathrm{rows}}\left(\widehat{\nabla_{\bm{\gamma}}\mathcal{L}}\right).
$$
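The two-stage structure (tile partials in the epilogue, then a small auxiliary reduction) can be sketched with plain lists. The helper names and tile sizes below are hypothetical, and a single synthetic matrix `prod` stands in for both elementwise products, $\bm{x}\odot\nabla_{\bm{x}}\mathcal{L}$ and $\nabla_{\bm{h}_2}\mathcal{L}\odot\bm{h}_1\odot\overline{\bm{r}}$, purely to keep the example short.

```python
def reduce_tile_cols(mat, tile_n):
    """Per-row partial sums: one partial per column tile of width tile_n."""
    return [[sum(row[c:c + tile_n]) for c in range(0, len(row), tile_n)] for row in mat]

def reduce_tile_rows(mat, tile_m):
    """Per-column partial sums: one row of partials per row tile of height tile_m."""
    return [[sum(col) for col in zip(*mat[r:r + tile_m])] for r in range(0, len(mat), tile_m)]

d = 6
prod = [[float(r * d + c) for c in range(d)] for r in range(4)]  # synthetic 4 x 6 product

s_hat = reduce_tile_cols(prod, tile_n=2)        # shape: rows x number of column tiles
s = [sum(partials) / d for partials in s_hat]   # auxiliary reduction, scaled by 1/d

g_hat = reduce_tile_rows(prod, tile_m=2)        # shape: number of row tiles x cols
grad_gamma = [sum(col) for col in zip(*g_hat)]  # auxiliary reduction over row tiles
```

The auxiliary reductions read only `s_hat` and `g_hat`, which are smaller than the activation by a factor of the tile width and tile height, respectively.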
Putting these pieces together, the backward pass is organized as follows:

$$
\begin{aligned}
\text{GEMM 1:}\quad \nabla_{\bm{h}_2}\mathcal{L} &= \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^{\top}, \\
\text{Epilogue 1:}\quad \nabla_{\bm{h}_1}\mathcal{L} &= \overline{\bm{r}}\odot\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\overline{\bm{\gamma}} - \bm{h}_1\odot\overline{\bm{r}}\odot\overline{\bm{s}}\right), \\
\nabla_{\bm{h}_0}\mathcal{L} &= g(\nabla_{\bm{h}_1}\mathcal{L}), \\
\widehat{\nabla_{\bm{\gamma}}\mathcal{L}} &= \operatorname{reduceTile}_{\mathrm{rows}}\left(\nabla_{\bm{h}_2}\mathcal{L}\odot\bm{h}_1\odot\overline{\bm{r}}\right), \\
\text{GEMM 2:}\quad \nabla_{\bm{x}}\mathcal{L} &= \nabla_{\bm{h}_0}\mathcal{L}\,\bm{W}_0^{\top}, \\
\text{Epilogue 2:}\quad \widehat{\bm{s}}_{\mathrm{prev}} &= \operatorname{reduceTile}_{\mathrm{cols}}\left(\bm{x}\odot\nabla_{\bm{x}}\mathcal{L}\right), \\
\text{Auxiliary reductions:}\quad \bm{s}_{\mathrm{prev}} &= \frac{1}{d}\operatorname{reduce}_{\mathrm{cols}}\left(\widehat{\bm{s}}_{\mathrm{prev}}\right), \\
\nabla_{\bm{\gamma}}\mathcal{L} &= \operatorname{reduce}_{\mathrm{rows}}\left(\widehat{\nabla_{\bm{\gamma}}\mathcal{L}}\right).
\end{aligned}
$$

Here $g$ denotes the tile-local backward rule for the epilogue $f$. The statistic $\bm{s}$ used in Epilogue 1 is assumed to have already been accumulated by the following pattern in the backward order.
Finally, the output of a pattern often passes through another epilogue before the next pattern begins:

$$
\bm{h}_0 = \bm{x}\bm{W}_0, \qquad
\bm{h}_1 = f_0(\bm{h}_0), \qquad
\bm{h}_2 = \operatorname{RMSNorm}(\bm{h}_1, \bm{\gamma}), \qquad
\bm{h}_3 = \bm{h}_2\bm{W}_1, \qquad
\bm{y} = f_1(\bm{h}_3).
$$

In this case, the statistic should be accumulated using the pre-epilogue tensor $\bm{h}_3$ and its gradient. The backward rule for $f_1$ is tile-local, so it can be fused before accumulating the statistic:

$$
\begin{aligned}
\text{GEMM 2:}\quad \nabla_{\bm{x}}\mathcal{L} &= \nabla_{\bm{h}_0}\mathcal{L}\,\bm{W}_0^{\top}, \\
\text{Epilogue 2:}\quad \nabla_{\overleftarrow{\bm{h}_3}}\mathcal{L} &= \overleftarrow{f_1}\left(\nabla_{\bm{x}}\mathcal{L}\right), \\
\widehat{\bm{s}}_{\mathrm{prev}} &= \operatorname{reduceTile}_{\mathrm{cols}}\left(\overleftarrow{\bm{h}_3}\odot\nabla_{\overleftarrow{\bm{h}_3}}\mathcal{L}\right).
\end{aligned}
$$

Here $\overleftarrow{\bm{h}_3}$ denotes the pre-epilogue GEMM output from the preceding pattern, and $\overleftarrow{f_1}$ denotes the local backward rule for its following epilogue. This case preserves the same structure: apply the local backward epilogue first, then accumulate the row-wise statistic from the pre-epilogue activation and its gradient.
Overall, this organization removes the activation\-sized RMSNorm backward kernel\. The remaining non\-local work consists only of reductions over tile partials: a column reduction for the row\-wise RMSNorm statistic, and a row reduction for the RMSNorm weight gradient\. The GEMMs and local backward updates are fused into GEMM epilogues, preserving the same GEMM\-plus\-epilogue structure used in the forward pass\.
## Appendix B CODA
### B.1 Epilogue Template
```python
epilogue.consumer_begin(...)
epilogue.producer_begin(...)


for epi_idx in range(num_epi_tiles):

    epilogue.consumer_begin_loop(...)

    epilogue.producer_tma_load(...)

    rD = load_accumulator_fragment(...)
    epilogue.consumer_visit(rD, ...)

    store_regs_to_smem(...)

    epilogue.consumer_smem_store(...)
    tma_store_from_smem_to_gmem(...)
    epilogue.consumer_tma_store(...)

    epilogue.consumer_end_loop(gmem_coord)


epilogue.consumer_end(...)
```
Listing 1: Epilogue Kernel Abstraction.
### B\.2Epilogue Example
1def\_create\_mean\_sq\_reduction\_op\(element\_type,inv\_block\_size\):
2"""Createareductionopthataccumulatesmeanofsquares:acc\+val^2\*inv\_block\_size\.
3
4Thecombine\_fnsquareseachnewelementandscalesby1/block\_sizebeforeadding
5totheaccumulator\.Thewarp\-levelreductionusesstandardadditionsincepartial
6sumsarealreadyaccumulatedandscaled\.
7"""
8init\_value=element\_type\(0\.\)
9inv\_bs=element\_type\(inv\_block\_size\)
10
11\_sq\_combine=lambdax,y:x\+y\*y\*inv\_bs
12\_add\_wrp=lambdatree\_x,tree\_y:pytree\.tree\_map\(operator\.add,tree\_x,tree\_y\)
13
14returnBlockReductionOp\(
15combine\_fn=lambdatree\_x,tree\_y:pytree\.tree\_map\(\_sq\_combine,tree\_x,tree\_y\),
16reduce\_ssa=None,
17reduce\_wrp=lambdaxs:pytree\.tree\_map\(
18lambdax:cute\.arch\.warp\_reduction\(
19x,
20op=\_add\_wrp,
21threads\_in\_group=HOPPER\_WARP\_REDUCTION\_WIDTH,
22\),
23xs,
24\),
25init\_value=init\_value,
26\)

```python
class EVTRowVecMulPostAct(EpilogueVisitorTree):
    """
    Loads a per-N row vector W (cp.async to smem, then s2r), multiplies the
    accumulator by W into a separate register tile, and stores that scaled
    tile to a side output mPostAct via TMA. tRS_rD itself is left unchanged
    so the main D output (the unscaled GEMM result) is unaffected.

    This mirrors the rowvec=norm_weight side-output path of trainstation's
    `gemm_partial_rms_fwd`, kept local to this kernel rather than as a
    general-purpose rapier EVT.

    Inputs:
        - GEMM output (in registers): [MxN], unchanged by this op
        - mRowVec: [L, N] - RMSNorm weight, broadcast along M

    Outputs:
        - mPostAct: [MxN] = D * W (side output, written via TMA)
    """

    @struct_utils.mlir_namedtuple
    class EpilogueArguments(NamedTuple):
        mPostAct: cute.Tensor | None
        mRowVec: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueParams(EpilogueVisitorTree.EpilogueParams):
        mPostAct: cute.Tensor | None
        mRowVec: cute.Tensor | None
        epi_tma_atom: cute.CopyAtom
        epi_gmem_layout: cutlass.utils.LayoutEnum
        epi_smem_layout_staged: cute.Layout

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueTensorsSMem(EpilogueVisitorTree.EpilogueTensorsSMem):
        sPostAct: cute.Tensor | None
        sRowVec: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueTensors(EpilogueVisitorTree.EpilogueTensors):
        tDsPostAct: cute.Tensor
        tDgPostAct: cute.Tensor
        tRS_sPostAct: cute.Tensor
        epi_tma_atom: cute.CopyAtom
        tiled_copy_postact_r2s: cute.TiledCopy
        tDsRowVec: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueTensorsLoop(EpilogueVisitorTree.EpilogueTensorsLoop):
        tDsPostAct: cute.Tensor
        tDgPostAct: cute.Tensor
        tRS_rPostAct: cute.Tensor | None
        tRS_sPostAct: cute.Tensor
        epi_tma_atom: cute.CopyAtom
        tiled_copy_postact_r2s: cute.TiledCopy
        tDrRowVec_epi: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpiloguePipelines(EpilogueVisitorTree.EpiloguePipelines):
        pass

    def __init__(
        self,
        acc_dtype: type[cute.Numeric],
        post_act_dtype: type[cute.Numeric],
        tile_shape_mnk: tuple[int, int, int],
        buffer_align_bytes: int,
    ) -> None:
        super().__init__()
        self.arch = 90
        self.acc_dtype = acc_dtype
        self.post_act_dtype = post_act_dtype
        self.container_dtype = post_act_dtype
        self.tile_shape_mnk = tile_shape_mnk
        self.buffer_align_bytes = buffer_align_bytes

    @cute.jit
    def to_underlying_arguments(
        self,
        epi_tile: cute.Tile,
        epi_stage: int,
        epi_load_stage: int,
        epi_args: EpilogueArguments,
    ) -> EpilogueParams:

        if cutlass.const_expr(epi_args.mPostAct is not None):
            mPostAct = misc_utils.static_assert_is_Tensor(epi_args.mPostAct)
            misc_utils.static_assert(get_dtype(mPostAct) is self.container_dtype)
            (
                epi_gmem_layout,
                epi_smem_layout_staged,
                epi_tma_atom,
                epi_tma_tensor,
            ) = epilogue_utils.prepare_tma(
                tma_op="s2g",
                epi_tile=epi_tile,
                epi_stage=epi_stage,
                epi_tensor=mPostAct,
            )

        if cutlass.const_expr(epi_args.mRowVec is not None):
            misc_utils.static_assert(epi_args.mPostAct is not None)
            mRowVec = misc_utils.static_assert_is_Tensor(epi_args.mRowVec)
            mRowVec = layout_utils.assumed_align_stride(
                mRowVec,
                assumed_align=4,
            )
        else:
            mRowVec = None

        return self.EpilogueParams(
            mPostAct=epi_tma_tensor,
            mRowVec=mRowVec,
            epi_tma_atom=epi_tma_atom,
            epi_gmem_layout=epi_gmem_layout,
            epi_smem_layout_staged=epi_smem_layout_staged,
        )

    @cute.jit
    def prefetch_tma_descriptors(
        self,
        epi_params: EpilogueParams,
    ) -> None:
        cute.nvgpu.cpasync.prefetch_descriptor(epi_params.epi_tma_atom)

    @cute.jit
    def consumer_begin(
        self,
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        tidx: cute.Int32,
        tiled_mma: cute.TiledMma,
        tRS_rD_layout: cute.Layout,
        epi_tile: cute.Tile,
        epi_num_threads: int,
        epi_num_matrices: int,
        epi_barrier: cutlass.pipeline.NamedBarrier,
        epi_params: EpilogueParams,
        epi_tensors_smem: EpilogueTensorsSMem,
    ) -> EpilogueTensors:

        tile_M = self.tile_shape_mnk[0]
        tile_N = self.tile_shape_mnk[1]
        m_idx, n_idx, _, batch_idx = tile_coord_mnkl
        thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)

        mPostAct = misc_utils.static_assert_is_Tensor(epi_params.mPostAct)
        sPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_smem.sPostAct)
        tiled_copy_postact_r2s, _, tRS_sPostAct = epilogue_utils.prepare_copy_r2s_sm90(
            tiled_copy_r2s=tiled_copy_r2s,
            tidx=tidx,
            dst=sPostAct,
            epi_layout=epi_params.epi_gmem_layout,
            epi_dtype=self.container_dtype,
            acc_dtype=self.acc_dtype,
        )
        gPostAct = mPostAct[None, None, batch_idx]
        gPostAct = cute.local_tile(gPostAct, (tile_M, tile_N), (m_idx, n_idx))
        gPostAct = cute.zipped_divide(gPostAct, epi_tile)

        tDsPostAct, tDgPostAct = cute.nvgpu.cpasync.tma_partition(
            atom=epi_params.epi_tma_atom,
            cta_coord=0,
            cta_layout=cute.make_layout(1),
            smem_tensor=cute.group_modes(sPostAct, 0, cute.rank(sPostAct) - 1),
            gmem_tensor=cute.group_modes(gPostAct, 0, cute.rank(gPostAct) - 1),
        )

        if cutlass.const_expr(epi_params.mRowVec is not None):
            mRowVec = misc_utils.static_assert_is_Tensor(epi_params.mRowVec)
            sRowVec = misc_utils.static_assert_is_Tensor(epi_tensors_smem.sRowVec)
            mRowVec = mRowVec[batch_idx, None]
            gRowVec = cute.local_tile(mRowVec, (tile_N,), (n_idx,))
            cRowVec = cute.make_identity_tensor(tile_N)
            limit_n = min(mRowVec.shape[0] - n_idx * tile_N, tile_N)
            memory_utils.g2s_copy_1d(
                src=gRowVec,
                dst=sRowVec,
                crd=cRowVec,
                shape=(limit_n,),
                num_threads=epi_num_threads,
                thread_index=tidx,
            )
            sRowVec_view_layout = cute.make_layout(
                shape=(tile_M, tile_N),
                stride=(0, 1),
            )
            sRowVec_view = cute.make_tensor(
                iterator=sRowVec.iterator,
                layout=sRowVec_view_layout,
            )
            tDsRowVec = thr_copy_r2s.partition_S(
                cute.flat_divide(sRowVec_view, epi_tile)
            )
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(0)
            epi_barrier.arrive_and_wait()
        else:
            tDsRowVec = None

        return self.EpilogueTensors(
            tDsPostAct=tDsPostAct,
            tDgPostAct=tDgPostAct,
            tRS_sPostAct=tRS_sPostAct,
            epi_tma_atom=epi_params.epi_tma_atom,
            tiled_copy_postact_r2s=tiled_copy_postact_r2s,
            tDsRowVec=tDsRowVec,
        )

    @cute.jit
    def consumer_end(
        self,
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        tidx: cute.Int32,
        shape_mnk: cute.Shape,
        epi_tile: cute.Tile,
        epi_num_threads: int,
        epi_barrier: cutlass.pipeline.NamedBarrier,
        epi_params: EpilogueParams,
        epi_tensors: EpilogueTensors,
        epi_tensors_smem: EpilogueTensorsSMem,
    ) -> None:
        pass

    @cute.jit
    def consumer_begin_loop(
        self,
        epi_coord: cute.Coord,
        epi_params: EpilogueParams,
        epi_tensors: EpilogueTensors,
        epi_pipelines: EpiloguePipelines,
    ) -> tuple[EpilogueTensorsLoop, EpiloguePipelines]:

        if cutlass.const_expr(epi_tensors.tDsRowVec is not None):
            tDsRowVec = misc_utils.static_assert_is_Tensor(epi_tensors.tDsRowVec)
            tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))
            tDsRowVec_cur = tDsRowVec_cur[None, None, None, epi_coord]
            tDrRowVec_cvt = memory_utils.s2r_copy_1d(tDsRowVec_cur, dtype=self.acc_dtype)
        else:
            tDrRowVec_cvt = None

        return (
            self.EpilogueTensorsLoop(
                tDsPostAct=epi_tensors.tDsPostAct,
                tDgPostAct=epi_tensors.tDgPostAct,
                tRS_rPostAct=None,
                tRS_sPostAct=epi_tensors.tRS_sPostAct,
                epi_tma_atom=epi_tensors.epi_tma_atom,
                tiled_copy_postact_r2s=epi_tensors.tiled_copy_postact_r2s,
                tDrRowVec_epi=tDrRowVec_cvt,
            ),
            self.EpiloguePipelines(),
        )

    @cute.jit
    def consumer_visit(
        self,
        tRS_rD: cute.Tensor,
        shape_mnk: cute.Shape,
        epi_params: EpilogueParams,
        epi_tensors_loop: EpilogueTensorsLoop,
    ) -> EpilogueTensorsLoop:

        tRS_rPostAct = creation_utils.allocate_tensor_like(
            tensor=tRS_rD,
            memspace="rmem",
            smem_allocator=None,
            dtype=self.acc_dtype,
        )
        if cutlass.const_expr(self.arch < 100):
            if cutlass.const_expr(epi_tensors_loop.tDrRowVec_epi is not None):
                tDrRowVec_epi = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tDrRowVec_epi)
                for i in cutlass.range_constexpr(cute.size(tRS_rPostAct)):
                    tRS_rPostAct[i] = tRS_rD[i] * tDrRowVec_epi[i]
            else:
                for i in cutlass.range_constexpr(cute.size(tRS_rPostAct)):
                    tRS_rPostAct[i] = tRS_rD[i]
        else:
            raise NotImplementedError

        tRS_rPostAct = dtype_utils.convert(
            tRS_rPostAct,
            dtype=self.post_act_dtype,
        )

        return self.EpilogueTensorsLoop(
            tDsPostAct=epi_tensors_loop.tDsPostAct,
            tDgPostAct=epi_tensors_loop.tDgPostAct,
            tRS_rPostAct=tRS_rPostAct,
            tRS_sPostAct=epi_tensors_loop.tRS_sPostAct,
            epi_tma_atom=epi_tensors_loop.epi_tma_atom,
            tiled_copy_postact_r2s=epi_tensors_loop.tiled_copy_postact_r2s,
            tDrRowVec_epi=epi_tensors_loop.tDrRowVec_epi,
        )

    @cute.jit
    def consumer_smem_store(
        self,
        epi_coord: cute.Coord,
        epi_buffer: cute.Int32,
        epi_params: EpilogueParams,
        epi_tensors_loop: EpilogueTensorsLoop,
    ) -> None:
        tiled_copy = epi_tensors_loop.tiled_copy_postact_r2s
        tRS_rPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tRS_rPostAct)
        tRS_sPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tRS_sPostAct)
        src = tiled_copy.retile(tRS_rPostAct)
        dst = tRS_sPostAct[None, None, None, epi_buffer]
        cute.copy(atom=tiled_copy, src=src, dst=dst)

    @cute.jit
    def consumer_tma_store(
        self,
        epi_coord: cute.Coord,
        epi_buffer: cute.Int32,
        epi_params: EpilogueParams,
        epi_tensors_loop: EpilogueTensorsLoop,
    ) -> None:
        atom = epi_tensors_loop.epi_tma_atom
        tDsPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tDsPostAct)
        tDgPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tDgPostAct)
        src = tDsPostAct[None, epi_buffer]
        dst = tDgPostAct[None, epi_coord]
        cute.copy(atom=atom, src=src, dst=dst)

    @cute.jit
    def get_smem_struct(
        self,
        epi_load_stage: int,
        epi_num_threads: int,
        epi_params: EpilogueParams,
    ) -> type[EpilogueSharedStorage]:

        if cutlass.const_expr(epi_params.mPostAct is not None):
            post_act_smem_size = cute.cosize(epi_params.epi_smem_layout_staged)
        else:
            post_act_smem_size = 0

        if cutlass.const_expr(epi_params.mRowVec is not None):
            mRowVec = misc_utils.static_assert_is_Tensor(epi_params.mRowVec)
            row_vec_dtype = get_dtype(mRowVec)
            row_vec_smem_size = epilogue_utils.get_smem_size_vector(
                mTensor=mRowVec,
                epi_tile=self.tile_shape_mnk[1],
                epi_num_threads=epi_num_threads,
            )
        else:
            row_vec_dtype = cute.Float32
            row_vec_smem_size = 0

        @cute.struct
        class SharedStorage(EpilogueSharedStorage):
            sPostAct: cute.struct.Align[
                cute.struct.MemRange[self.container_dtype, post_act_smem_size],
                self.buffer_align_bytes,
            ]
            sRowVec: cute.struct.Align[
                cute.struct.MemRange[row_vec_dtype, row_vec_smem_size],
                16,
            ]

        return SharedStorage

    @cute.jit
    def get_smem_tensors(
        self,
        storage: EpilogueSharedStorage,
        epi_num_threads: int,
        epi_params: EpilogueParams,
    ) -> EpilogueTensorsSMem:

        if cutlass.const_expr(epi_params.mPostAct is not None):
            sPostAct = storage.sPostAct.get_tensor(
                epi_params.epi_smem_layout_staged.outer,
                swizzle=epi_params.epi_smem_layout_staged.inner,
            )
        else:
            sPostAct = None

        if cutlass.const_expr(epi_params.mRowVec is not None):
            sRowVec_layout = cute.make_layout(self.tile_shape_mnk[1])
            sRowVec = storage.sRowVec.get_tensor(sRowVec_layout)
        else:
            sRowVec = None

        return self.EpilogueTensorsSMem(
            sPostAct=sPostAct,
            sRowVec=sRowVec,
        )

    @cute.jit
    def get_smem_bytes_per_stage(
        self,
        epi_tile: cute.Tile,
        epi_num_threads: int,
        epi_args: EpilogueArguments,
    ) -> tuple[int, int, int]:
        epi_smem_bytes_fixed = 0
        epi_smem_bytes_per_stage_cst = 0
        epi_smem_bytes_per_stage_pld = 0

        if cutlass.const_expr(epi_args.mPostAct is not None):
            mPostAct = misc_utils.static_assert_is_Tensor(epi_args.mPostAct)
            misc_utils.static_assert(get_dtype(mPostAct) is self.container_dtype)
            epi_smem_bytes_per_stage_cst = epi_smem_bytes_per_stage_cst + (
                epilogue_utils.get_epi_smem_bytes_per_stage_matrix(
                    mTensor=mPostAct,
                    epi_tile=epi_tile,
                )
            )

        if cutlass.const_expr(epi_args.mRowVec is not None):
            mRowVec = misc_utils.static_assert_is_Tensor(epi_args.mRowVec)
            epi_smem_bytes_fixed = epi_smem_bytes_fixed + (
                epilogue_utils.get_epi_smem_bytes_per_stage_fixed_vector(
                    mTensor=mRowVec,
                    epi_tile=self.tile_shape_mnk[1],
                    epi_num_threads=epi_num_threads,
                )
            )

        return (
            epi_smem_bytes_fixed,
            epi_smem_bytes_per_stage_cst,
            epi_smem_bytes_per_stage_pld,
        )


def prepare_epilogue(
    shape_mnkl: tuple[int, int, int, int],
    tile_shape_mn: tuple[int, int],
    C: torch.Tensor,
    S: torch.Tensor,
    W: torch.Tensor,
    O: torch.Tensor,
) -> tuple[
    Callable[..., EpilogueVisitorTree],
    EpilogueVisitorTree.EpilogueArguments,
    dict,
    tuple,
]:
    """Prepare epilogue for GEMM with residual, partial mean-of-squares, and
    fused per-N RMSNorm-weight scaling - mirrors trainstation's `gemm_partial_rms_fwd`.

    Composes three EVT visitors:
        1. EVTResidual: D = acc + C
        2. EVTColBlockReductionStore: S[m, nb] = mean(D[m, nb*bs:(nb+1)*bs]^2)
        3. EVTRowVecMulPostAct (local): O[m, n] = D[m, n] * W[n], side output via TMA

    The partial sum-of-squares is computed on the *unscaled* D, so a downstream
    rstd reduction sees the GEMM output before W is applied. tRS_rD is preserved
    so the main D output is also unscaled.

    Args:
        shape_mnkl: Problem shape (M, N, K, L) where L is batch dimension.
        tile_shape_mn: CTA tile shape (tile_M, tile_N).
        C: Residual matrix of shape (M, N).
        S: Output for partial mean-of-squares of shape (M, num_blocks) in fp32.
        W: RMSNorm weight of shape (N,), broadcast across M.
        O: Output of shape (M, N) for D * W.

    Returns:
        Tuple of (epi_cls, epi_args, epi_outs, epi_keys).
    """
    M, N, K, L = shape_mnkl

    epi_dtype = torch2cute_dtype_map[C.dtype]
    post_act_dtype = torch2cute_dtype_map[O.dtype]

    epi_cls = lambda acc_dtype, tile_shape_mnk, buffer_align_bytes: EVTList([
        EVTResidual(
            acc_dtype=acc_dtype,
            epi_dtype=epi_dtype,
            tile_shape_mnk=tile_shape_mnk,
            buffer_align_bytes=buffer_align_bytes,
        ),
        EVTColBlockReductionStore(
            reduction_op=_create_mean_sq_reduction_op(
                element_type=acc_dtype,
                inv_block_size=1.0 / tile_shape_mnk[1],
            ),
            tile_shape_mnk=tile_shape_mnk,
        ),
        EVTRowVecMulPostAct(
            acc_dtype=acc_dtype,
            post_act_dtype=post_act_dtype,
            tile_shape_mnk=tile_shape_mnk,
            buffer_align_bytes=buffer_align_bytes,
        ),
    ])

    epi_args = EVTList.EpilogueArguments([
        EVTResidual.EpilogueArguments(
            mMatrix=C,
        ),
        EVTColBlockReductionStore.EpilogueArguments(
            mColVec=S,
        ),
        EVTRowVecMulPostAct.EpilogueArguments(
            mPostAct=O,
            mRowVec=W,
        ),
    ])

    epi_keys = (
        C.dtype,
        S.dtype,
        W.dtype,
        O.dtype,
        EVTResidual,
        EVTColBlockReductionStore,
        EVTRowVecMulPostAct,
    )

    epi_outs = {}

    return epi_cls, epi_args, epi_outs, epi_keys
```

Listing 2: Kernel Example\.
## Appendix C Experiments
### C\.1 List of Kernels
We summarize the kernels implemented in CODA\. Each kernel is a GEMM followed by an epilogue program\.
#### C\.1\.1 Basic Epilogue Kernels
We first list three basic GEMM\-plus\-epilogue kernels\. These are useful for isolating individual epilogue primitives, although they are not always used directly in the Transformer forward pass\.
Kernel 1: GEMM with RoPE\. This kernel applies RoPE \[[15](https://arxiv.org/html/2605.19269#bib.bib15)\] to pairs of adjacent features in the GEMM output:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{O} &= \operatorname{RoPE}(\bm{D}).
\end{aligned}
$$
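As an unfused reference for what this epilogue computes, the following NumPy sketch rotates each adjacent feature pair of the GEMM output. The frequency schedule (`base=10000.0`), the single-head layout, and the function names are assumptions made for illustration; they are not taken from the CODA code.

```python
import numpy as np


def rope_adjacent_pairs(d, cos, sin):
    """Rotate each adjacent pair (d[:, 2i], d[:, 2i+1]) by the angle for slot i."""
    even, odd = d[:, 0::2], d[:, 1::2]
    out = np.empty_like(d)
    out[:, 0::2] = even * cos - odd * sin
    out[:, 1::2] = even * sin + odd * cos
    return out


def gemm_rope_reference(a, b, positions, base=10000.0):
    """Unfused reference: D = A @ B, then O = RoPE(D) applied row by row."""
    d = a @ b                                        # (M, N), N must be even
    half = d.shape[1] // 2
    inv_freq = base ** (-np.arange(half) / half)     # assumed frequency schedule
    angles = positions[:, None] * inv_freq[None, :]  # (M, N/2)
    return rope_adjacent_pairs(d, np.cos(angles), np.sin(angles))


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8))
b = rng.standard_normal((8, 6))
positions = np.arange(4, dtype=np.float64)
o = gemm_rope_reference(a, b, positions)
```

Because both members of a pair sit in neighbouring columns, each rotation only needs values from a single output tile (for an even tile width), which is what lets it run as an epilogue.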
Kernel 2: GEMM with SwiGLU\. This kernel applies a fused SwiGLU activation to an interleaved GEMM output:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
[\bm{G}, \bm{U}] &= \operatorname{interleavedSplit}(\bm{D}), \\
\bm{O} &= \operatorname{silu}(\bm{G}) \odot \bm{U}.
\end{aligned}
$$
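A minimal NumPy sketch of the interleaved layout, assuming that even output columns hold the gate and odd columns hold the up projection (the paper does not fix this convention in this section, so treat it as an assumption, as are the helper names):

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def interleave_columns(w_gate, w_up):
    """Build a weight whose even columns are w_gate and odd columns are w_up."""
    k, n = w_gate.shape
    w = np.empty((k, 2 * n), dtype=w_gate.dtype)
    w[:, 0::2] = w_gate
    w[:, 1::2] = w_up
    return w


def gemm_swiglu_reference(a, w_interleaved):
    """Unfused reference: D = A @ W, split D into (G, U), O = silu(G) * U."""
    d = a @ w_interleaved
    g, u = d[:, 0::2], d[:, 1::2]   # interleavedSplit under the assumed layout
    return silu(g) * u


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8))
w_gate = rng.standard_normal((8, 5))
w_up = rng.standard_normal((8, 5))
o = gemm_swiglu_reference(a, interleave_columns(w_gate, w_up))
```

Interleaving places each gate value next to its matching up value, so the elementwise product can be formed from a single output tile instead of two separate GEMM outputs.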
Kernel 3: GEMM with partial cross\-entropy\. This kernel computes logits, selects the target logit, and emits block\-wise log\-sum\-exp statistics for the cross\-entropy loss:

$$
\begin{aligned}
\bm{Z} &= \bm{A}\bm{B}, \\
\bm{z}_{\mathrm{tgt}} &= \bm{Z}[\bm{y}], \\
\widehat{\bm{l}}_{\mathrm{lse}} &= \operatorname{reduceTile}_{\log\sum\exp}(\bm{Z}).
\end{aligned}
$$
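The block-wise statistics are sufficient because a log-sum-exp over partial log-sum-exps equals the full log-sum-exp. The NumPy sketch below illustrates this under assumed names and an assumed tiling along the vocabulary dimension; the follow-up reduction shown here is a plausible finishing step, not the paper's implementation.

```python
import numpy as np


def gemm_partial_cross_entropy(a, b, y, tile_n):
    """Reference: logits, target logit, and one log-sum-exp per column tile."""
    z = a @ b                                   # (M, V) logits
    m, v = z.shape
    z_tgt = z[np.arange(m), y]                  # Z[y], one logit per row
    num_tiles = -(-v // tile_n)                 # ceil(V / tile_n)
    partial_lse = np.empty((m, num_tiles))
    for t in range(num_tiles):
        blk = z[:, t * tile_n:(t + 1) * tile_n]
        mx = blk.max(axis=1, keepdims=True)     # subtract the max for stability
        partial_lse[:, t] = mx[:, 0] + np.log(np.exp(blk - mx).sum(axis=1))
    return z_tgt, partial_lse


def finish_cross_entropy(z_tgt, partial_lse):
    """Combine per-tile statistics: loss = logsumexp(partials) - target logit."""
    mx = partial_lse.max(axis=1, keepdims=True)
    lse = mx[:, 0] + np.log(np.exp(partial_lse - mx).sum(axis=1))
    return lse - z_tgt


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 8))
b = rng.standard_normal((8, 10))
y = rng.integers(0, 10, size=5)
z_tgt, partial_lse = gemm_partial_cross_entropy(a, b, y, tile_n=4)
loss = finish_cross_entropy(z_tgt, partial_lse)
```

Only the `(M, num_tiles)` statistics and the target logits need to leave the kernel; the full logit matrix never has to be re-read to finish the loss.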
#### C\.1\.2 Forward\-Pass Kernels
The following kernels implement the reparameterized Transformer forward pass\. They compose the basic epilogue primitives with RMSNorm scaling, residual updates, and partial reductions\.
Kernel 4: GEMM with residual, partial RMSNorm, and weight scaling\. This kernel implements the first stage of the GEMM–Residual–RMSNorm–GEMM pattern\. It forms the residual\-updated activation, emits partial RMS statistics, and applies the RMSNorm weight:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B} + \bm{C}, \\
\widehat{\bm{r}} &= \operatorname{reduceTile}_{\mathrm{cols}}(\bm{D} \odot \bm{D}), \\
\bm{O} &= \bm{D} \odot \bm{\gamma}.
\end{aligned}
$$
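The sketch below mirrors the three outputs in NumPy and shows one way the partial statistics could be finished into a row-wise inverse RMS. Following Listing 2, each partial is a per-tile mean of squares; the `eps` value, the helper names, and the requirement that `tile_n` divides `N` are assumptions for illustration.

```python
import numpy as np


def gemm_residual_partial_rms(a, b, c, gamma, tile_n):
    """Reference for Kernel 4: D = A @ B + C, per-tile mean of D^2, O = D * gamma."""
    d = a @ b + c
    m, n = d.shape
    assert n % tile_n == 0, "sketch assumes tile_n divides N"
    # S[m, nb] = mean(D[m, nb*tile_n:(nb+1)*tile_n] ** 2), computed on the unscaled D
    s = (d * d).reshape(m, n // tile_n, tile_n).mean(axis=2)
    o = d * gamma                                # side output with the weight applied
    return d, s, o


def finish_rstd(s, eps=1e-6):
    """Equal-sized tiles: the mean of tile means is the full-row mean of squares."""
    return 1.0 / np.sqrt(s.mean(axis=1) + eps)


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 6))
b = rng.standard_normal((6, 8))
c = rng.standard_normal((4, 8))
gamma = rng.standard_normal(8)
d, s, o = gemm_residual_partial_rms(a, b, c, gamma, tile_n=4)
rstd = finish_rstd(s)
```

The second reduction touches only an `(M, N / tile_n)` array, which is much smaller than the activation itself.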
Kernel 5: GEMM with RMSNorm scaling\. This kernel consumes a precomputed row\-wise normalization factor and applies it in the GEMM epilogue:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{O} &= \bm{D} \odot \bm{r}.
\end{aligned}
$$
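Deferring the normalization to the epilogue of the next GEMM is valid because a per-row scale commutes with right-multiplication by a weight matrix. A small NumPy check of that identity, with assumed names and an assumed `eps`, where the GEMM input is the weight-scaled activation from Kernel 4:

```python
import numpy as np


def rmsnorm(x, gamma, eps=1e-6):
    """Standard RMSNorm: scale each row by its inverse RMS, then by gamma."""
    r = 1.0 / np.sqrt((x * x).mean(axis=1, keepdims=True) + eps)
    return x * r * gamma


def gemm_rms_scale(o, w, r):
    """Reference for Kernel 5: D = O @ W, then scale row i of D by r[i]."""
    return (o @ w) * r[:, None]


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))          # residual-updated activation D from Kernel 4
gamma = rng.standard_normal(8)
w = rng.standard_normal((8, 5))
r = 1.0 / np.sqrt((x * x).mean(axis=1) + 1e-6)

unfused = rmsnorm(x, gamma) @ w          # normalize first, then GEMM
deferred = gemm_rms_scale(x * gamma, w, r)   # GEMM on the weight-scaled input, scale after
```

In exact arithmetic the two results agree; in low precision they can differ slightly because the GEMM sees unnormalized inputs.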
Kernel 6: GEMM with RMSNorm and SwiGLU\. This kernel composes row\-wise RMSNorm scaling with SwiGLU, corresponding to the MLP gate/up projection:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{D}' &= \bm{D} \odot \bm{r}, \\
[\bm{G}, \bm{U}] &= \operatorname{interleavedSplit}(\bm{D}'), \\
\bm{O} &= \operatorname{silu}(\bm{G}) \odot \bm{U}.
\end{aligned}
$$
Kernel 7: GEMM with RMSNorm and RoPE\. This kernel composes row\-wise RMSNorm scaling with RoPE, corresponding to QKV projection followed by rotary positional embedding:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{D}' &= \bm{D} \odot \bm{r}, \\
\bm{O} &= \operatorname{RoPE}(\bm{D}').
\end{aligned}
$$
Kernel 8: GEMM with RMSNorm and partial cross\-entropy\. This kernel adds row\-wise RMSNorm scaling before target\-logit selection and partial log\-sum\-exp reduction, corresponding to the language\-modeling head:

$$
\begin{aligned}
\bm{Z} &= (\bm{A}\bm{B}) \odot \bm{r}, \\
\bm{z}_{\mathrm{tgt}} &= \bm{Z}[\bm{y}], \\
\widehat{\bm{l}}_{\mathrm{lse}} &= \operatorname{reduceTile}_{\log\sum\exp}(\bm{Z}).
\end{aligned}
$$
#### C\.1\.3 Backward\-Pass Kernels
Finally, we list the backward kernels\. These kernels mirror the forward structure: each performs a GEMM, applies the local backward rule in the epilogue, and emits partial reductions needed by neighboring RMSNorm backward computations\.
Kernel 9: GEMM with residual and RMSNorm backward\. This kernel implements the local part of RMSNorm backward\. Let $\bm{C}$ denote the RMSNorm input, $\bm{r}$ the row\-wise inverse RMS factor, $\bm{\gamma}$ the RMSNorm weight, and $\bm{z}_{\Delta z}$ the row\-wise RMSNorm backward statistic:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}^{\top}, \\
\bm{C}_{\mathrm{norm}} &= \bm{C} \odot \bm{r}, \\
\bm{O}_{\mathrm{out}} &= \bm{O}_{\mathrm{in}} + \left(\bm{D} \odot \bm{\gamma} - \bm{C}_{\mathrm{norm}} \odot \bm{z}_{\Delta z}\right) \odot \bm{r}, \\
\bm{C}_{\mathrm{out}} &= \bm{C}_{\mathrm{norm}} \odot \bm{\gamma}, \\
\widehat{\nabla_{\bm{\gamma}} \mathcal{L}} &= \operatorname{reduceTile}_{\mathrm{rows}}\left(\bm{D} \odot \bm{C}_{\mathrm{norm}}\right).
\end{aligned}
$$
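The epilogue formulas can be checked against finite differences of a plain RMSNorm. The sketch below passes the upstream gradient $\bm{D}$ in directly rather than forming it with a GEMM, and it assumes that $\bm{z}_{\Delta z}$ is the row-wise mean of $\bm{D} \odot \bm{\gamma} \odot \bm{C}_{\mathrm{norm}}$. That normalization, the `eps` value, and all function names are assumptions of this sketch, since this section does not define them.

```python
import numpy as np


def rmsnorm(c, gamma, eps=1e-6):
    r = 1.0 / np.sqrt((c * c).mean(axis=1) + eps)
    return c * r[:, None] * gamma, r


def rmsnorm_backward_epilogue(d, c, r, gamma, z, o_in):
    """Elementwise part of Kernel 9, given the upstream gradient d."""
    c_norm = c * r[:, None]
    o_out = o_in + (d * gamma - c_norm * z[:, None]) * r[:, None]
    c_out = c_norm * gamma
    dgamma = (d * c_norm).sum(axis=0)    # per-tile row reductions, summed here
    return o_out, c_out, dgamma


def numerical_grad(f, x, h=1e-6):
    """Central finite differences of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        e = np.zeros_like(x)
        e[idx] = h
        g[idx] = (f(x + e) - f(x - e)) / (2 * h)
    return g


rng = np.random.default_rng(0)
c = rng.standard_normal((3, 8))
gamma = rng.standard_normal(8)
d = rng.standard_normal((3, 8))          # gradient w.r.t. the RMSNorm output
y, r = rmsnorm(c, gamma)
z = (d * gamma * c * r[:, None]).mean(axis=1)   # assumed definition of z_{Delta z}
o_out, c_out, dgamma = rmsnorm_backward_epilogue(d, c, r, gamma, z, np.zeros_like(c))

num_dc = numerical_grad(lambda cc: (d * rmsnorm(cc, gamma)[0]).sum(), c)
num_dgamma = numerical_grad(lambda gg: (d * rmsnorm(c, gg)[0]).sum(), gamma)
```

With $\bm{O}_{\mathrm{in}} = 0$, `o_out` should match `num_dc` and `dgamma` should match `num_dgamma` up to finite-difference error; a nonzero $\bm{O}_{\mathrm{in}}$ adds the gradient arriving through the residual path.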
Kernel 10: GEMM with SwiGLU backward\. This kernel computes the backward pass of a fused SwiGLU epilogue and emits the row\-wise statistic needed by the preceding RMSNorm backward\. Let $\bm{Z}$ be the saved interleaved pre\-activation tensor:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}^{\top}, \\
[\bm{G}, \bm{U}] &= \operatorname{interleavedSplit}(\bm{Z}), \\
\bm{O} &= \operatorname{silu}(\bm{G}) \odot \bm{U}, \\
\nabla_{\bm{U}} \mathcal{L} &= \bm{D} \odot \operatorname{silu}(\bm{G}), \\
\nabla_{\bm{G}} \mathcal{L} &= \bm{D} \odot \bm{U} \odot \left(\sigma(\bm{G}) + \operatorname{silu}(\bm{G}) \odot (1 - \sigma(\bm{G}))\right), \\
\nabla_{\bm{Z}} \mathcal{L} &= \operatorname{interleavedConcat}\left(\nabla_{\bm{G}} \mathcal{L}, \nabla_{\bm{U}} \mathcal{L}\right), \\
\widehat{\bm{z}_{\Delta z}} &= \operatorname{reduceTile}_{\mathrm{cols}}\left(\bm{G} \odot \nabla_{\bm{G}} \mathcal{L} + \bm{U} \odot \nabla_{\bm{U}} \mathcal{L}\right).
\end{aligned}
$$
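A NumPy sketch of the same backward rule, checked against finite differences. As before, the upstream gradient $\bm{D}$ is supplied directly, even columns of $\bm{Z}$ are assumed to hold the gate, and the tile width is measured in columns of $\bm{Z}$; these conventions and the function names are illustrative assumptions.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def swiglu_forward(z):
    g, u = z[:, 0::2], z[:, 1::2]
    return g * sigmoid(g) * u


def swiglu_backward_epilogue(d, z, tile_n):
    """Elementwise part of Kernel 10, given the upstream gradient d."""
    g, u = z[:, 0::2], z[:, 1::2]
    sig = sigmoid(g)
    silu_g = g * sig
    grad_u = d * silu_g
    grad_g = d * u * (sig + silu_g * (1.0 - sig))
    grad_z = np.empty_like(z)
    grad_z[:, 0::2] = grad_g             # interleavedConcat under the assumed layout
    grad_z[:, 1::2] = grad_u
    m, n = z.shape
    # Per-tile row sums of Z * grad_Z, i.e. G * grad_G + U * grad_U within each tile
    partial = (z * grad_z).reshape(m, n // tile_n, tile_n).sum(axis=2)
    return grad_z, partial


def numerical_grad(f, x, h=1e-6):
    """Central finite differences of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        e = np.zeros_like(x)
        e[idx] = h
        g[idx] = (f(x + e) - f(x - e)) / (2 * h)
    return g


rng = np.random.default_rng(0)
z = rng.standard_normal((3, 8))
d = rng.standard_normal((3, 4))          # gradient w.r.t. O = silu(G) * U
grad_z, partial = swiglu_backward_epilogue(d, z, tile_n=4)
num_grad_z = numerical_grad(lambda zz: (d * swiglu_forward(zz)).sum(), z)
```

Summing `partial` over tiles gives one scalar per row, which is the quantity a neighbouring RMSNorm backward needs (up to whatever normalization that kernel applies).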
### C\.2 Setup Details
Experiments are conducted using a single H100 GPU\. We use the following package versions\.
1. PyTorch 2\.10\.0
2. CuTeDSL 4\.4\.2
3. Liger Kernels 0\.8\.0
4. FlashInfer 0\.6\.10\.post1
5. QuACK Kernels 0\.4\.1