CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs

Summary

Introduces CODA, a GPU kernel abstraction that expresses Transformer operations as GEMM-plus-epilogue programs to reduce data movement, covering nearly all non-attention computation in a Transformer block.

Cached at: 05/22/26, 06:43 AM

# CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs
Source: [https://arxiv.org/html/2605.19269](https://arxiv.org/html/2605.19269)
Han Guo (1), Jack Zhang (2), Arjun Menon (2), Driss Guessous (4), Vijay Thakkar (4), Yoon Kim (1), Tri Dao (2, 3)

(1) Massachusetts Institute of Technology; (2) Princeton University; (3) Together AI; (4) Meta

Contact: hanguo@mit.edu

###### Abstract

Transformer training systems are built around dense linear algebra, yet a nontrivial fraction of end-to-end time is spent on surrounding memory-bound operators. Normalization, activations, residual updates, reductions, and related computations repeatedly move large intermediate tensors through global memory while performing little arithmetic, making data movement an increasingly important bottleneck in otherwise highly optimized training stacks. We introduce CODA, a GPU kernel abstraction that expresses these computations as GEMM-plus-epilogue programs. CODA is based on the observation that many Transformer operators exposed as separate framework kernels can be algebraically reparameterized to execute while a GEMM output tile remains on chip, before it is written to memory. The abstraction fixes the GEMM mainloop and exposes a small set of composable epilogue primitives for scaling, reductions, pairwise transformations, and accumulation. This constrained interface preserves the performance structure of expert-written GEMMs while remaining expressive enough to cover nearly all non-attention computation in the forward and backward pass of a standard Transformer block. Across representative Transformer workloads, both human- and LLM-authored CODA kernels achieve high performance, suggesting that GEMM-plus-epilogue programming offers a practical path toward combining framework-level productivity with hardware-level efficiency.¹

¹ Code available at [https://github.com/HanGuo97/coda-kernels](https://github.com/HanGuo97/coda-kernels).

## 1 Introduction

![Figure 1](https://arxiv.org/html/2605.19269v2/x1.png)
Figure 1: Runtime breakdown for LLaMA-3-style 1B model training on a single H100 using TorchTitan.

LLM training has become just as much of a systems problem as a modeling one. FLOPs in modern Transformer-based LLMs are dominated by matrix multiplications and attention, whose kernels have been heavily optimized for Tensor Core execution. Yet Transformers, and deep learning architectures more broadly, also contain normalization, activations, residual updates, reductions, and other bandwidth-limited operations that move large tensors through memory while doing little arithmetic. Prior work has shown that data movement is a central bottleneck in Transformer training\[[7](https://arxiv.org/html/2605.19269#bib.bib7)\]; as [Figure 1](https://arxiv.org/html/2605.19269#S1.F1) shows, when training a LLaMA-3-style 1B model on a single H100 using TorchTitan\[[11](https://arxiv.org/html/2605.19269#bib.bib11)\], these non-GEMM operations account for a nontrivial fraction of end-to-end runtime. As hardware increasingly accelerates low-precision matrix multiplication through formats such as FP8 and FP4, this bottleneck is becoming more important, as the cost of materializing intermediate tensors does not improve at the same rate.

Existing programming models make this issue difficult to address. High-level frameworks such as PyTorch express Transformer blocks as operator sequences, with autograd making the backward pass similarly convenient. This is productive, but operator boundaries often become materialization boundaries and obscure fusion opportunities across forward and backward computation. Production-level LLM systems therefore often bypass framework abstractions with hand-written backward passes or custom kernels, as in large-scale LLaMA training\[[5](https://arxiv.org/html/2605.19269#bib.bib5)\] and inference systems\[[9](https://arxiv.org/html/2605.19269#bib.bib9),[25](https://arxiv.org/html/2605.19269#bib.bib25)\]. This work asks whether there is a middle ground: is it possible to recover much of the performance of custom kernels without giving up the structure needed for programmability and automation?

Our starting observation is that many Transformer computations that appear at the framework level as separate operators can be algebraically reparameterized as GEMM-plus-epilogue programs ([Figure 2](https://arxiv.org/html/2605.19269#S1.F2)). In this form, a highly optimized GEMM mainloop produces output tiles, while a programmable epilogue performs tile-local transformations before the result is written to memory ([Figure 3](https://arxiv.org/html/2605.19269#S2.F3)). This is efficient on GPUs because the epilogue operates on data the GEMM tile has already produced on chip, avoiding additional global-memory round trips for intermediate tensors. With modern pipelined schedules, such as Hopper Ping-Pong GEMM and Blackwell TMEM-based pipelines, this epilogue work can often be overlapped with, and thus hidden behind, the mainloops of other tiles. We therefore treat the epilogue not merely as a place for simple post-processing such as scaling or bias addition, but as a structured interface for fusing memory-bound computation into the lifetime of a GEMM tile.

Based on the above, we introduce CODA, a kernel abstraction prototype that realizes this interface. CODA keeps the GEMM mainloop fixed and exposes a small set of composable epilogue primitives for scaling, reductions, pairwise transformations, and accumulation. This programming model is deliberately constrained yet expressive: after reparameterization, these primitives cover nearly all non-attention computation in the forward and backward pass of a standard Transformer block while preserving efficiency. CODA inserts computation into the epilogue of a known high-performance GEMM before intermediate tensors are materialized in global memory, capturing a broad class of memory-bound computations surrounding dense linear algebra. Transformers are our primary application, but the same GEMM-plus-epilogue view applies more broadly whenever high-throughput matrix multiplication is surrounded by tile-expressible, data-movement-bound computations.

Finally, this structure makes automation more practical. Epilogue fusion is already established in high-performance GEMM libraries, but applying it to Transformer workloads remains a low-level engineering task. CODA targets this gap by providing Transformer-specific epilogue primitives on top of a tuned GEMM mainloop. Human or LLM-based authors can assemble these primitives into reparameterized Transformer kernels rather than synthesizing arbitrary CUDA. Across representative workloads, both authoring modes achieve high performance, suggesting that domain-specific epilogue abstractions can make established GEMM fusion techniques more programmable for LLM kernels.

![Figure 2](https://arxiv.org/html/2605.19269v2/x2.png)
Figure 2: Forward pass of a standard Transformer layer. The top row shows the canonical formulation, which maps to a mix of compute- and memory-bound kernels. We reparameterize the computation so that most memory-bound operations are subsumed into the epilogues of compute-bound kernels.
## 2 Background and Related Works

### 2.1 Programming Models for LLM Systems

Modern LLM systems are programmed at multiple abstraction levels. Frameworks such as PyTorch and JAX express models as tensor-operator graphs and integrate naturally with automatic differentiation, but operator boundaries often become materialization boundaries, where intermediate tensors are written to and read back from global memory.

Compiler systems lower tensor programs to optimized kernels through graph rewriting, scheduling, code generation, and autotuning\[[1](https://arxiv.org/html/2605.19269#bib.bib1),[2](https://arxiv.org/html/2605.19269#bib.bib2),[19](https://arxiv.org/html/2605.19269#bib.bib19)\]. Algebraic reformulation is another important source of performance, as shown by TASO\[[8](https://arxiv.org/html/2605.19269#bib.bib8)\] and Mirage\[[22](https://arxiv.org/html/2605.19269#bib.bib22)\]. However, rapidly evolving accelerators make peak performance a moving target for general-purpose compilers.

Closer to the hardware level, programmers use kernel DSLs and libraries such as Triton\[[19](https://arxiv.org/html/2605.19269#bib.bib19)\], ThunderKittens\[[14](https://arxiv.org/html/2605.19269#bib.bib14),[13](https://arxiv.org/html/2605.19269#bib.bib13),[17](https://arxiv.org/html/2605.19269#bib.bib17)\], TileLang\[[20](https://arxiv.org/html/2605.19269#bib.bib20)\], CuTeDSL\[[18](https://arxiv.org/html/2605.19269#bib.bib18)\], Gluon, and TLX, or rely on specialized LLM kernels in vLLM\[[9](https://arxiv.org/html/2605.19269#bib.bib9)\], SGLang\[[25](https://arxiv.org/html/2605.19269#bib.bib25)\], FlashInfer\[[23](https://arxiv.org/html/2605.19269#bib.bib23)\], and Liger Kernels\[[6](https://arxiv.org/html/2605.19269#bib.bib6)\]\. These approaches deliver high performance, but extending them to new transformations or backward computations still requires substantial low\-level engineering\.

### 2.2 GEMM Mainloops and Epilogue Fusion

Matrix multiplication is the central compute primitive in modern LLM workloads\. A high\-performance GEMM kernel is typically divided into a mainloop and an epilogue\. The mainloop performs the tiled matrix multiply\-accumulate computation, while the epilogue transforms the computed output tile and efficiently writes it back to global memory\.
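The mainloop/epilogue split can be illustrated with a minimal pure-Python sketch. It is not CUTLASS or CuTeDSL code, and the names `tiled_gemm` and `bias_relu` are hypothetical: the "mainloop" accumulates one output tile over the full K dimension, and the epilogue receives that accumulator tile plus its tile offsets before the single store to the output matrix. On a GPU the tile would live in registers or shared memory; here it is simply a nested list.

```python
from typing import Callable, List

Matrix = List[List[float]]
# An epilogue maps (accumulator tile, tile row offset, tile col offset) -> tile.
Epilogue = Callable[[Matrix, int, int], Matrix]


def tiled_gemm(a: Matrix, b: Matrix, tile_m: int, tile_n: int,
               epilogue: Epilogue) -> Matrix:
    """Compute epilogue(a @ b), one (tile_m x tile_n) output tile at a time."""
    m, k, n = len(a), len(b), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, tile_m):
        for j0 in range(0, n, tile_n):
            rows = range(i0, min(i0 + tile_m, m))
            cols = range(j0, min(j0 + tile_n, n))
            # "Mainloop": accumulate the full K reduction for this output tile.
            acc = [[sum(a[i][kk] * b[kk][j] for kk in range(k)) for j in cols]
                   for i in rows]
            # "Epilogue": transform the tile while it is still local ...
            tile = epilogue(acc, i0, j0)
            # ... and only then perform the single store to the output.
            for ti, i in enumerate(rows):
                for tj, j in enumerate(cols):
                    out[i][j] = tile[ti][tj]
    return out


def bias_relu(bias: List[float]) -> Epilogue:
    """Example epilogue: per-column bias addition followed by ReLU."""
    def epilogue(acc: Matrix, i0: int, j0: int) -> Matrix:
        return [[max(0.0, v + bias[j0 + tj]) for tj, v in enumerate(row)]
                for row in acc]
    return epilogue


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b = [[1.0, -1.0, 0.0], [0.0, 1.0, 2.0]]
y = tiled_gemm(a, b, tile_m=2, tile_n=2, epilogue=bias_relu([0.0, -2.0, 0.5]))
```

The tile offsets `i0` and `j0` are what let an epilogue index auxiliary operands such as `bias` consistently with the output tile it is holding.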

![Figure 3](https://arxiv.org/html/2605.19269v2/x3.png)
Figure 3: A GEMM mainloop computes output tiles; the epilogue transforms each tile before the final global-memory store.

The epilogue is a natural place to implement fusions because the output of the matmul is already on chip, close to the compute cores. Practical epilogues commonly perform scaling, bias addition, activations, residual updates, data type conversions, tile-wise reductions, and other elementwise operations on the output, avoiding separate kernel launches and extra global-memory round trips. Modern kernel libraries formalize this separation directly: CUTLASS\[[18](https://arxiv.org/html/2605.19269#bib.bib18)\] represents GEMM kernels as a composition of a collective mainloop and a collective epilogue, while Epilogue Visitor Trees further express epilogues as compositions of primitives\[[4](https://arxiv.org/html/2605.19269#bib.bib4)\].

This flexibility operates under a locality constraint\. An epilogue sees only the local output tile, its accumulators, and consistently indexed auxiliary tensors, meaning that operations requiring global reductions or cross\-tile communication must be reformulated into tile\-local pieces or handled in a separate pass\. CODA builds on this interface, keeping the high\-performance GEMM mainloop fixed and using the epilogue as a programmable site for nearby memory\-bound computation\.

## 3 CODA

The previous section argued that GEMM epilogues are a natural place to fuse memory-bound computation into dense linear algebra. We now describe CODA, a GPU kernel abstraction that realizes this idea. [Section 3.1](https://arxiv.org/html/2605.19269#S3.SS1) identifies a small set of epilogue primitives that map efficiently to GPU execution. [Section 3.2](https://arxiv.org/html/2605.19269#S3.SS2) shows how the non-attention and non-embedding portions of the Transformer forward and backward pass can be reparameterized using these primitives. Finally, [Section 3.3](https://arxiv.org/html/2605.19269#S3.SS3) describes their implementation and our LLM-oriented authoring workflow.

### 3.1 Efficient Epilogue Primitives

CODA programs the GEMM epilogue while keeping the mainloop fixed and highly optimized. For each output tile, an epilogue may load auxiliary data, transform accumulator values, emit auxiliary results, and store the final output. This interface is deliberately restricted to tile-local computation rather than arbitrary global communication. Our epilogue template, shown in [Section B.1](https://arxiv.org/html/2605.19269#A2.SS1), is inspired by Epilogue Visitor Trees\[[4](https://arxiv.org/html/2605.19269#bib.bib4)\]. CODA provides five classes of epilogue primitives:

1. *Elementwise and pairwise maps:* apply local transformations to accumulator values, including residual updates, activations, RoPE-style rotations, and SwiGLU-style gates.
2. *Vector (Rank-1 Tensor) loads and stores:* load row or column vectors, broadcast them over an output tile, and optionally write vector-valued auxiliary results.
3. *Tile (Rank-2 Tensor) loads and stores:* load or store matrix tiles, such as residual streams, saved activations, or intermediate values needed by the backward pass.
4. *Tile (Rank-2 Tensor) reductions:* compute partial reductions over rows or columns of an output tile, to be combined later by a lightweight auxiliary kernel.
5. *Stateful transforms:* maintain running tile state, such as the max and sum-exp statistics used in online log-sum-exp and cross-entropy.
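A rough sketch of how such primitives might compose, in plain Python. Every name here (`TileContext`, `elementwise`, `load_tile_add`, `load_vector_mul`, `row_partial_sum`, `compose`) is hypothetical and does not reflect CODA's actual CuTeDSL-based API. Each primitive takes an accumulator tile plus per-tile context and returns a tile; the reduction primitive writes partials to an auxiliary output instead of communicating across tiles. Stateful transforms are omitted for brevity.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple

Tile = List[List[float]]


@dataclass
class TileContext:
    """Per-tile information visible to epilogue primitives (hypothetical)."""
    row0: int  # global row index of the tile's first row
    col0: int  # global column index of the tile's first column
    aux_out: Dict[str, List[Tuple[int, int, List[float]]]] = field(default_factory=dict)


Primitive = Callable[[Tile, TileContext], Tile]


def elementwise(fn: Callable[[float], float]) -> Primitive:
    """Elementwise map over every accumulator value."""
    return lambda t, ctx: [[fn(v) for v in row] for row in t]


def load_tile_add(other: Tile) -> Primitive:
    """Tile load: add the matching tile of a rank-2 operand (e.g. a residual)."""
    return lambda t, ctx: [[v + other[ctx.row0 + i][ctx.col0 + j]
                            for j, v in enumerate(row)] for i, row in enumerate(t)]


def load_vector_mul(vec: List[float]) -> Primitive:
    """Vector load: broadcast a per-column vector over the tile and multiply."""
    return lambda t, ctx: [[v * vec[ctx.col0 + j] for j, v in enumerate(row)]
                           for row in t]


def row_partial_sum(name: str, fn: Callable[[float], float]) -> Primitive:
    """Tile reduction: emit one partial per tile row; a later pass combines tiles."""
    def prim(t: Tile, ctx: TileContext) -> Tile:
        partial = [sum(fn(v) for v in row) for row in t]
        ctx.aux_out.setdefault(name, []).append((ctx.row0, ctx.col0, partial))
        return t  # the tile itself passes through unchanged
    return prim


def compose(*prims: Primitive) -> Primitive:
    """Chain primitives left to right into a single epilogue."""
    def epilogue(t: Tile, ctx: TileContext) -> Tile:
        for p in prims:
            t = p(t, ctx)
        return t
    return epilogue


# Residual add -> partial sum of squares -> scale by gamma, on one 2x2 tile.
residual = [[1.0, 1.0], [2.0, 2.0]]
gamma = [0.5, 2.0]
epi = compose(load_tile_add(residual),
              row_partial_sum("sumsq", lambda v: v * v),
              load_vector_mul(gamma))
ctx = TileContext(row0=0, col0=0)
out = epi([[1.0, 2.0], [3.0, 4.0]], ctx)
```

This chain resembles the first epilogue of the GEMM-Residual-RMSNorm-GEMM pattern in Section 3.2.1; a real implementation would additionally fix register layouts and decide where each operand is staged.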

These primitives are intentionally narrow, operating at a level low enough to compile to efficient epilogue code and expressive enough to capture the memory\-bound operations surrounding GEMMs in our Transformer reparameterizations, as shown next\.

### 3.2 Reparameterizing Transformers as Epilogues

We now show that the primitive set above is sufficient for much of Transformer computation\. After lightweight algebraic reparameterization, many non\-attention and non\-embedding components of a standard Transformer forward pass can be written as

$$
\text{GEMM:}\quad \boldsymbol{h} = \boldsymbol{x}\boldsymbol{W}, \qquad \text{Epilogue:}\quad \boldsymbol{y}[i,j] = \boldsymbol{f}[i,j]\bigl(\boldsymbol{h}[i,j]\bigr),
$$

where $[i,j]$ indexes an output tile and $\boldsymbol{f}[i,j]$ is the tile function implemented in the GEMM epilogue. The epilogue is either fully tile-local, or tile-local up to partial results that are combined by a lightweight auxiliary reduction. We first apply this view to the forward pass, then show that independent tile functions preserve the same GEMM-epilogue structure in the backward pass.

#### 3.2.1 GEMM-Residual-RMSNorm-GEMM Pattern

A recurring pattern in pre\-normalized Transformers is a GEMM followed by a residual update and normalization, then another GEMM\. This pattern appears across several adjacent sublayers:

1. attention output projection → residual stream → RMSNorm → MLP gate/up projection;
2. MLP down projection → residual stream → RMSNorm → attention QKV projection;
3. final MLP down projection → residual stream → final RMSNorm → language modeling head.

Although these cases are usually written as parts of different modules, they share the same computational structure:

$$
\boldsymbol{y} = \operatorname{RMSNorm}\bigl(\boldsymbol{x}\boldsymbol{W}_0 + \boldsymbol{z}, \boldsymbol{\gamma}\bigr)\boldsymbol{W}_1 = \Bigl(r\,\bigl(\boldsymbol{x}\boldsymbol{W}_0 + \boldsymbol{z}\bigr) \odot \boldsymbol{\gamma}\Bigr)\boldsymbol{W}_1,
$$

where $\boldsymbol{z}$ denotes the residual stream and $r = 1/\operatorname{rms}(\boldsymbol{x}\boldsymbol{W}_0 + \boldsymbol{z})$ is the row-wise inverse RMS factor. This pattern crosses the usual module boundary: it couples the output projection of one sublayer with the input projection of the next.

Residual addition and multiplication by the RMSNorm weight $\boldsymbol{\gamma}$ are tile-local, so they can be fused into a GEMM epilogue. The row-wise factor $r$, however, requires a reduction across the full hidden dimension, which spans multiple output tiles. In the canonical computation, $r$ is applied before the second GEMM, creating an apparent dependency between normalization and the next projection.

![Figure 4](https://arxiv.org/html/2605.19269v2/x4.png)
Figure 4: GEMM-RMSNorm-GEMM reparameterization.

We address the reduction by splitting it into two levels. The first GEMM epilogue computes tile-local partial reductions, and a small auxiliary kernel reduces these partials across tiles to obtain $r$. Since the auxiliary kernel reads a few partial values per tile rather than the full activation tensor, its memory traffic is much smaller than that of a standalone RMSNorm.

The apparent dependency on $r$ can be removed algebraically. Since $r$ is a single scalar per row, scaling each row by $r$ commutes with right-multiplication by the second weight matrix, so it commutes with the second GEMM:

$$
\boldsymbol{y} = \Bigl(r\,\bigl(\boldsymbol{x}\boldsymbol{W}_0 + \boldsymbol{z}\bigr) \odot \boldsymbol{\gamma}\Bigr)\boldsymbol{W}_1 = r\,\Bigl(\bigl(\boldsymbol{x}\boldsymbol{W}_0 + \boldsymbol{z}\bigr) \odot \boldsymbol{\gamma}\Bigr)\boldsymbol{W}_1.
$$

Thus, the row-wise scale does not need to be applied before the second GEMM. It can instead be delayed to the epilogue of the second GEMM, after the projection has been computed.

Concretely, the computation decomposes into two GEMMs and one lightweight reduction ([Figure 4](https://arxiv.org/html/2605.19269#S3.F4)):

$$
\begin{aligned}
\text{GEMM 1:}\quad & \boldsymbol{h}_0 = \boldsymbol{x}\boldsymbol{W}_0, \qquad & \text{Epilogue 1:}\quad & \boldsymbol{h}_1[i,j] = \boldsymbol{h}_0[i,j] + \boldsymbol{z}[i,j], \\
& & & \boldsymbol{h}_2[i,j] = \boldsymbol{h}_1[i,j] \odot \boldsymbol{\gamma}[j], \\
& & & \widehat{\boldsymbol{r}}[i,j] = \operatorname{partialRMS}\bigl(\boldsymbol{h}_1[i,j]\bigr), \\
\text{GEMM 2:}\quad & \boldsymbol{h}_3 = \boldsymbol{h}_2\boldsymbol{W}_1, \qquad & \text{Epilogue 2:}\quad & \boldsymbol{y}[i,j] = r[i]\,\boldsymbol{h}_3[i,j].
\end{aligned}
$$
![Figure 5](https://arxiv.org/html/2605.19269v2/x5.png)
Figure 5: Benchmarks.

Here $r = 1/\sqrt{\operatorname{reduce}(\widehat{\boldsymbol{r}}) + \epsilon}$ is computed by a small auxiliary reduction over the tile partials. This decomposition replaces a standalone RMSNorm kernel with tile-local epilogue work around the two GEMMs, plus a lightweight auxiliary reduction.
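A small numerical check of this decomposition, written in plain Python rather than as a GPU kernel. Two details are our assumptions, not the paper's definitions: `partialRMS` is taken to be a tile's partial sum of squares, and `reduce` sums those partials and divides by the hidden size $d$, so that $r$ matches the usual RMSNorm. Column tiles are processed one row at a time for brevity, and the function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def reference(x, w0, z, gamma, w1, eps=1e-6):
    """Canonical path: GEMM -> residual add -> RMSNorm -> GEMM."""
    h = [[v + zv for v, zv in zip(row, zrow)] for row, zrow in zip(matmul(x, w0), z)]
    d = len(h[0])
    normed = []
    for row in h:
        r = 1.0 / math.sqrt(sum(v * v for v in row) / d + eps)
        normed.append([r * v * g for v, g in zip(row, gamma)])
    return matmul(normed, w1)


def reparameterized(x, w0, z, gamma, w1, tile_n=2, eps=1e-6):
    """GEMM 1 + epilogue 1, a tiny reduction over partials, GEMM 2 + epilogue 2."""
    h0 = matmul(x, w0)                                   # GEMM 1
    m, d = len(h0), len(h0[0])
    h2 = [[0.0] * d for _ in range(m)]
    partials: List[List[float]] = [[] for _ in range(m)]
    for i in range(m):                                   # epilogue 1, tile by tile
        for j0 in range(0, d, tile_n):
            cols = range(j0, min(j0 + tile_n, d))
            h1 = [h0[i][j] + z[i][j] for j in cols]      # residual add
            for j, v in zip(cols, h1):
                h2[i][j] = v * gamma[j]                  # multiply by gamma
            partials[i].append(sum(v * v for v in h1))   # partial sum of squares
    # Auxiliary reduction: reads one partial per column tile, not the full row.
    r = [1.0 / math.sqrt(sum(p) / d + eps) for p in partials]
    h3 = matmul(h2, w1)                                  # GEMM 2 on unnormalized h2
    return [[r[i] * v for v in row] for i, row in enumerate(h3)]  # epilogue 2


x = [[1.0, 0.5], [-2.0, 1.0]]
w0 = [[0.3, -1.0, 2.0, 0.1], [1.5, 0.2, -0.7, 0.4]]
z = [[0.1, 0.2, -0.3, 0.0], [1.0, -1.0, 0.5, 2.0]]
gamma = [1.0, 0.5, 2.0, 1.5]
w1 = [[1.0, 0.0], [0.5, -1.0], [0.2, 0.3], [-0.4, 1.0]]
ref = reference(x, w0, z, gamma, w1)
rep = reparameterized(x, w0, z, gamma, w1)
```

The two paths agree only up to floating-point rounding, since moving $r$ past $\boldsymbol{W}_1$ is exact in real arithmetic but changes the order of operations.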

In [Figure 5](https://arxiv.org/html/2605.19269#S3.F5), we benchmark this reparameterization against existing implementations on LLaMA-style configurations with a batch of 16K tokens. We vary the hidden dimension across representative model scales, with $d \in \lbrace 2048, 4096, 8192 \rbrace$ corresponding roughly to 1B, 7B, and 70B models, respectively. Our GEMM-Epilogue kernel is generated by an LLM provided with the above abstractions (explained in more detail in [Section 3.3.1](https://arxiv.org/html/2605.19269#S3.SS3.SSS1)).

![Figure 6](https://arxiv.org/html/2605.19269v2/x6.png)
Figure 6: Relative error.

##### Numerics.

The reparameterization changes where the RMSNorm scale is applied: the row-wise factor $r$ is delayed from before the second GEMM to the second GEMM epilogue. We compare BF16 GEMM-RMSNorm-GEMM outputs against an FP32 reference on Llama-3 8B layers. We report the errors of CODA and QuACK, on which our GEMM template is based, normalized by the error of the standard PyTorch path. [Figure 6](https://arxiv.org/html/2605.19269#S3.F6) suggests that a more accurate GEMM mainloop can reduce numerical error, and that CODA's reparameterized epilogue can reduce it further.

#### 3.2.2 GEMM with Pairwise Activations

A second common pattern in Transformers is a GEMM followed by a *pairwise* activation. Unlike an elementwise activation, which transforms each feature independently, a pairwise activation consumes two adjacent feature values and produces one or two outputs:

$$
\boldsymbol{h} = \boldsymbol{x}\boldsymbol{W}, \qquad \boldsymbol{h}_a[i,j],\ \boldsymbol{h}_b[i,j] = \operatorname{split}\bigl(\boldsymbol{h}[i,j]\bigr), \qquad \boldsymbol{y}[i,j] = \boldsymbol{f}[i,j]\bigl(\boldsymbol{h}_a[i,j], \boldsymbol{h}_b[i,j]\bigr).
$$
![Figure 7](https://arxiv.org/html/2605.19269v2/x7.png)
Figure 7: Pairwise activations operate on local feature pairs in the GEMM epilogue.

This form captures several operations in Transformer blocks:

- RoPE rotates each feature pair and returns two outputs;
- SwiGLU combines the gate and value streams into one output;
- the SwiGLU backward pass maps one incoming gradient to gradients for both paired inputs.

Pairwise activations couple neighboring feature lanes and may change the feature dimension. A naive implementation materializes the GEMM output, splits it into pairs, and applies the activation in a separate kernel. This adds memory traffic and, as in SwiGLU, can require materializing an intermediate twice as wide as the activation output.

Instead, we arrange paired features to be adjacent along the output-feature dimension. This matches the Hopper Tensor Core accumulator layout exposed to the epilogue, where each thread holds a small tuple of adjacent output values in registers before they are stored. The epilogue can therefore apply $f$ directly to each pair with register-level computation.

This removes the standalone activation kernel and avoids materializing the paired intermediate in global memory. The same idea applies to dimension-preserving operations such as RoPE, dimension-reducing operations such as SwiGLU, and dimension-expanding operations in the backward pass, as long as the pairing is reflected in the GEMM output layout. See [Figure 8](https://arxiv.org/html/2605.19269#S3.F8) for performance benchmarks.
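The following plain-Python sketch illustrates the three kinds of pairwise maps. Interleaving the gate and up projection columns (gate_0, up_0, gate_1, up_1, ...) is our own illustration of arranging paired features to be adjacent; the real CODA layout is tied to the Tensor Core accumulator fragment, which lists do not model. The RoPE map assumes the adjacent-pair convention with base 10000 (some implementations rotate the two halves of the head dimension instead, which would need a different pairing). All function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def interleave_columns(w_gate: Matrix, w_up: Matrix) -> Matrix:
    """Columns alternate gate_k, up_k, so each (gate, up) pair is adjacent in h."""
    return [[v for g, u in zip(grow, urow) for v in (g, u)]
            for grow, urow in zip(w_gate, w_up)]


def swiglu_epilogue(h_row: List[float]) -> List[float]:
    """Dimension-reducing pairwise map: (gate, up) -> silu(gate) * up."""
    return [silu(h_row[2 * k]) * h_row[2 * k + 1] for k in range(len(h_row) // 2)]


def swiglu_backward_epilogue(dy_row: List[float], h_row: List[float]) -> List[float]:
    """Dimension-expanding pairwise map: one gradient -> (d_gate, d_up)."""
    out: List[float] = []
    for k, dy in enumerate(dy_row):
        g, u = h_row[2 * k], h_row[2 * k + 1]
        sig = 1.0 / (1.0 + math.exp(-g))
        out += [dy * u * sig * (1.0 + g * (1.0 - sig)),  # dy * u * d silu(g)/dg
                dy * g * sig]                              # dy * silu(g)
    return out


def rope_epilogue(h_row: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Dimension-preserving pairwise map: rotate each adjacent (a, b) pair."""
    n_pairs = len(h_row) // 2
    out: List[float] = []
    for k in range(n_pairs):
        theta = pos * base ** (-k / n_pairs)
        a, b = h_row[2 * k], h_row[2 * k + 1]
        c, s = math.cos(theta), math.sin(theta)
        out += [a * c - b * s, a * s + b * c]
    return out


x = [[0.5, -1.0, 2.0]]
w_gate = [[0.2, -0.4], [0.7, 0.1], [-0.3, 0.5]]
w_up = [[1.0, 0.3], [-0.6, 0.8], [0.4, -0.2]]
h = matmul(x, interleave_columns(w_gate, w_up))[0]   # one GEMM, pairs adjacent
fused = swiglu_epilogue(h)                             # half as many outputs as h
unfused = [silu(g) * u for g, u in zip(matmul(x, w_gate)[0], matmul(x, w_up)[0])]
```

One GEMM over the interleaved weight produces `h`, and the epilogue emits the SwiGLU output directly, so the wide gate/up intermediate never needs to be written out.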

![Figure 8](https://arxiv.org/html/2605.19269v2/x8.png)
Figure 8: Kernel-level speedups for representative GEMM-plus-epilogue primitives across $MNK$ sizes. RoPE uses an output dimension of $3N$ for QKV projections, and cross-entropy uses a 32K vocabulary. Speedups are relative to cuBLAS with `torch.compile`.
#### 3.2.3 GEMM with Cross-Entropy Loss

Cross-entropy loss can also be expressed as a GEMM with epilogue-side reductions, as shown by Cut Cross-Entropy\[[21](https://arxiv.org/html/2605.19269#bib.bib21)\]. Let $\boldsymbol{h}_i = \boldsymbol{x}_i\boldsymbol{W}_{\mathrm{lm}}$ be the logits for token $i$, and let $\boldsymbol{y}_i$ be its target label. The per-token loss is

$$
\ell_i = -\boldsymbol{h}_{i,\boldsymbol{y}_i} + \log\sum_k \exp\bigl(\boldsymbol{h}_{i,k}\bigr).
$$

Thus, the loss decomposes into an indexed logit and a row-wise log-sum-exp over vocabulary entries.

Both terms fit the GEMM-plus-epilogue pattern. The indexed logit can be selected from the GEMM output tile using the target label, while the log-sum-exp (LSE) can be accumulated as tile-local maximum and sum-exp statistics. A small auxiliary reduction then combines these statistics across tiles, avoiding a standalone memory-bound softmax over the full logits.² See [Figure 8](https://arxiv.org/html/2605.19269#S3.F8) for performance benchmarks.

² We use a separate final reduction rather than atomics, and materialize logits to simplify the backward pass.
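A plain-Python sketch of this decomposition for a single token, assuming the logits row arrives in column tiles of width `tile_n` that stand in for GEMM output tiles. Each tile emits a (max, sum-exp) pair and, if it contains the target, the indexed logit; a small auxiliary pass then merges the pairs with the standard online log-sum-exp rule. The names are hypothetical, and this is neither Cut Cross-Entropy's nor CODA's code.

```python
import math
from typing import List, Tuple

Stats = Tuple[float, float]  # (running max, sum of exp(logit - running max))


def tile_stats(tile: List[float]) -> Stats:
    """Tile-local part of the stateful transform."""
    m = max(tile)
    return m, sum(math.exp(v - m) for v in tile)


def merge_stats(a: Stats, b: Stats) -> Stats:
    """Combine two partial statistics without overflowing exp()."""
    (m1, s1), (m2, s2) = a, b
    m = max(m1, m2)
    return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)


def fused_cross_entropy(logits_row: List[float], target: int, tile_n: int) -> float:
    if not 0 <= target < len(logits_row):
        raise ValueError("target index out of range")
    partials: List[Stats] = []
    target_logit = 0.0
    for j0 in range(0, len(logits_row), tile_n):
        tile = logits_row[j0:j0 + tile_n]
        partials.append(tile_stats(tile))        # epilogue: emit tile statistics
        if j0 <= target < j0 + len(tile):
            target_logit = tile[target - j0]     # epilogue: pick the indexed logit
    m, s = partials[0]
    for p in partials[1:]:                       # auxiliary reduction across tiles
        m, s = merge_stats((m, s), p)
    return -target_logit + m + math.log(s)       # loss = -h[y] + logsumexp(h)


logits = [2.0, -1.0, 0.5, 3.0, 1.0]
loss = fused_cross_entropy(logits, target=3, tile_n=2)
naive = -logits[3] + math.log(sum(math.exp(v) for v in logits))
```

Carrying the running maximum is what keeps the merge numerically safe for large logits, where the naive `sum(exp(...))` would overflow.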

#### 3.2.4 Backward Pass

The preceding sections show that much of the Transformer forward pass can be reparameterized as GEMMs with epilogues, plus lightweight auxiliary reductions\. We now show that the backward pass preserves the same structure\.

##### GEMM with elementwise epilogue\.

Consider two GEMMs separated by an elementwise epilogue:

$$
\boldsymbol{h} = \boldsymbol{x}\boldsymbol{W}_0, \qquad \boldsymbol{h}' = f(\boldsymbol{h}), \qquad \boldsymbol{y} = \boldsymbol{h}'\boldsymbol{W}_1,
$$

where $f$ is applied elementwise. Given an upstream gradient $\nabla_{\boldsymbol{y}}\mathcal{L}$, reverse-mode differentiation gives

$$
\nabla_{\boldsymbol{h}'}\mathcal{L} = \nabla_{\boldsymbol{y}}\mathcal{L}\,\boldsymbol{W}_1^\top, \qquad \nabla_{\boldsymbol{h}}\mathcal{L} = \nabla_{\boldsymbol{h}'}\mathcal{L} \odot f'(\boldsymbol{h}), \qquad \nabla_{\boldsymbol{x}}\mathcal{L} = \nabla_{\boldsymbol{h}}\mathcal{L}\,\boldsymbol{W}_0^\top.
$$

Thus, the backward computation has the same structure as the forward computation: GEMM, local transformation, GEMM. The only difference is the direction of fusion. In the forward pass, $f$ is fused into the epilogue of the *preceding* GEMM that produces $\boldsymbol{h}$; in the backward pass, multiplication by $f'(\boldsymbol{h})$ is fused into the epilogue of the *following* GEMM that produces $\nabla_{\boldsymbol{h}'}\mathcal{L}$ ([Figure 9](https://arxiv.org/html/2605.19269#S3.F9)).
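A minimal pure-Python check of this structure with $f = \tanh$, assuming the forward pass saves $\boldsymbol{h}$ for the backward pass (one common choice; recomputation is another). The backward GEMM $\nabla_{\boldsymbol{y}}\mathcal{L}\,\boldsymbol{W}_1^\top$ applies $f'(\boldsymbol{h})$ in its epilogue, so $\nabla_{\boldsymbol{h}'}\mathcal{L}$ is never stored separately. The helper names are hypothetical.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def gemm_epilogue(a: Matrix, b: Matrix,
                  epi: Callable[[float, int, int], float]) -> Matrix:
    """GEMM whose output element (i, j) passes through epi before being stored."""
    return [[epi(v, i, j) for j, v in enumerate(row)]
            for i, row in enumerate(matmul(a, b))]


def forward(x: Matrix, w0: Matrix, w1: Matrix):
    h = matmul(x, w0)                                 # GEMM 1 output, saved for backward
    hp = [[math.tanh(v) for v in row] for row in h]   # f (a fused kernel does this in GEMM 1's epilogue)
    return h, matmul(hp, w1)                          # GEMM 2


def backward_dx(dy: Matrix, h: Matrix, w0: Matrix, w1: Matrix) -> Matrix:
    # GEMM dy @ W1^T whose epilogue multiplies by f'(h) = 1 - tanh(h)^2,
    # reading the saved h tile-wise, so its output is already dL/dh.
    dh = gemm_epilogue(dy, transpose(w1),
                       lambda v, i, j: v * (1.0 - math.tanh(h[i][j]) ** 2))
    return matmul(dh, transpose(w0))                  # dL/dx = dL/dh @ W0^T


x = [[0.2, -0.4, 0.9], [1.1, 0.3, -0.5]]
w0 = [[0.5, -0.3], [0.8, 0.1], [-0.2, 0.7]]
w1 = [[1.0, -0.5, 0.3], [0.4, 0.9, -1.2]]
dy = [[0.3, -0.7, 1.0], [0.5, 0.2, -0.4]]   # upstream gradient dL/dy
h, y = forward(x, w0, w1)
dx = backward_dx(dy, h, w0, w1)
```

Taking $\mathcal{L} = \sum \boldsymbol{y} \odot \nabla_{\boldsymbol{y}}\mathcal{L}$ with a fixed `dy` makes the result easy to compare against finite differences.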

###### Theorem 1\.

Consider a sequence of GEMM\-with\-epilogue blocks followed by a final GEMM:

$$
\begin{aligned}
\boldsymbol{h}_\ell &= \boldsymbol{x}_{\ell-1}\boldsymbol{W}_\ell, \qquad \boldsymbol{x}_\ell[i,j] = \boldsymbol{f}_\ell[i,j]\bigl(\boldsymbol{h}_\ell[i,j]\bigr), \qquad \ell = 1,\ldots,L-1, \\
\boldsymbol{h}_L &= \boldsymbol{x}_{L-1}\boldsymbol{W}_L.
\end{aligned}
$$

Assume each tile function $\boldsymbol{f}_\ell[i,j]$ acts only on its corresponding GEMM output tile. Then the activation gradients can be computed with the same GEMM-with-epilogue structure:

$$
\begin{aligned}
\nabla_{\boldsymbol{x}_{\ell-1}}\mathcal{L} &= \nabla_{\boldsymbol{h}_\ell}\mathcal{L}\,\boldsymbol{W}_\ell^\top, \qquad \nabla_{\boldsymbol{h}_{\ell-1}}\mathcal{L}[i,j] = \boldsymbol{g}_{\ell-1}[i,j]\bigl(\nabla_{\boldsymbol{x}_{\ell-1}}\mathcal{L}[i,j],\ \boldsymbol{h}_{\ell-1}[i,j]\bigr), \qquad \ell = L,\ldots,2, \\
\nabla_{\boldsymbol{x}_0}\mathcal{L} &= \nabla_{\boldsymbol{h}_1}\mathcal{L}\,\boldsymbol{W}_1^\top.
\end{aligned}
$$

Here $\boldsymbol{g}_\ell[i,j]$ is the tile-local backward rule for $\boldsymbol{f}_\ell[i,j]$: it maps the gradient of the epilogue output tile $\boldsymbol{x}_\ell[i,j]$ to the gradient of the epilogue input tile $\boldsymbol{h}_\ell[i,j]$. The weight gradients are GEMMs:

$$
\nabla_{\boldsymbol{W}_\ell}\mathcal{L} = \boldsymbol{x}_{\ell-1}^\top\,\nabla_{\boldsymbol{h}_\ell}\mathcal{L}, \qquad \ell = 1,\ldots,L.
$$

Thus, tile-local epilogues in the forward pass induce tile-local epilogues in the backward pass, while the surrounding linear maps remain GEMMs.

###### Proof sketch\.

For the GEMM $\boldsymbol{h}_\ell = \boldsymbol{x}_{\ell-1}\boldsymbol{W}_\ell$, reverse-mode differentiation gives

$$
\nabla_{\boldsymbol{x}_{\ell-1}}\mathcal{L} = \nabla_{\boldsymbol{h}_\ell}\mathcal{L}\,\boldsymbol{W}_\ell^\top, \qquad \nabla_{\boldsymbol{W}_\ell}\mathcal{L} = \boldsymbol{x}_{\ell-1}^\top\,\nabla_{\boldsymbol{h}_\ell}\mathcal{L},
$$

which are both GEMMs. For the epilogue $\boldsymbol{x}_\ell[i,j] = \boldsymbol{f}_\ell[i,j]\bigl(\boldsymbol{h}_\ell[i,j]\bigr)$, define the local backward rule by

$$
\nabla_{\boldsymbol{h}_\ell}\mathcal{L}[i,j] = \boldsymbol{g}_\ell[i,j]\bigl(\nabla_{\boldsymbol{x}_\ell}\mathcal{L}[i,j],\ \boldsymbol{h}_\ell[i,j]\bigr).
$$

Since each $\boldsymbol{f}_\ell[i,j]$ depends only on its own tile, the Jacobian of the epilogue is block-diagonal over tiles, so the corresponding vector-Jacobian product $\boldsymbol{g}_\ell[i,j]$ also acts only on that tile. The backward pass therefore introduces no new cross-tile communication and preserves the GEMM-with-epilogue structure. ∎

![Figure 9](https://arxiv.org/html/2605.19269v2/x9.png)
Figure 9: Forward and backward fusion for GEMM–epilogue blocks. Forward epilogues attach to the GEMM that produces their input, while backward epilogues attach to the GEMM that produces the gradient with respect to their output.
##### RMSNorm backward\.

RMSNorm is the main case where the backward pass is not purely tile\-local\. Its backward rule introduces two reductions: a row\-wise statistic needed for the input gradient, and a feature\-wise reduction across rows for the RMSNorm weight gradient\. A direct implementation computes both in a standalone RMSNorm backward kernel, requiring additional reads of activation\-sized tensors\. However, the row\-wise statistic can be moved to a neighboring GEMM boundary\. Consider

$$
\boldsymbol{h}_0 = \boldsymbol{x}\boldsymbol{W}_0, \qquad \boldsymbol{h}_1 = f(\boldsymbol{h}_0), \qquad \boldsymbol{h}_2 = \operatorname{RMSNorm}(\boldsymbol{h}_1, \boldsymbol{\gamma}), \qquad \boldsymbol{y} = \boldsymbol{h}_2\boldsymbol{W}_1.
$$

RMSNorm backward requires the row-wise inner product

$$
\boldsymbol{s} = \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\bigl(\nabla_{\boldsymbol{h}_2}\mathcal{L} \odot \boldsymbol{h}_2\bigr).
$$

Using $\nabla_{\boldsymbol{h}_2}\mathcal{L} = \nabla_{\boldsymbol{y}}\mathcal{L}\boldsymbol{W}_1^\top$ and $\boldsymbol{y} = \boldsymbol{h}_2\boldsymbol{W}_1$, together with the fact that for any row vectors $\boldsymbol{g}$ and $\boldsymbol{v}$ we have $\langle \boldsymbol{g}\boldsymbol{W}_1^\top, \boldsymbol{v} \rangle = \langle \boldsymbol{g}, \boldsymbol{v}\boldsymbol{W}_1 \rangle$, this statistic can be equivalently written as

$$
\boldsymbol{s} = \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\bigl(\nabla_{\boldsymbol{y}}\mathcal{L} \odot \boldsymbol{y}\bigr).
$$

This identity changes where the statistic is computed. Instead of launching a standalone RMSNorm backward kernel to read $\boldsymbol{h}_2$ and $\nabla_{\boldsymbol{h}_2}\mathcal{L}$, we accumulate the same row-wise quantity at a boundary where $\boldsymbol{y}$ and $\nabla_{\boldsymbol{y}}\mathcal{L}$ are already available, thereby exposing the computation to epilogue fusion.
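A quick numerical check of the identity in plain Python, with made-up data. It compares the statistic as a standalone RMSNorm backward would compute it (from $\boldsymbol{h}_2$ and $\nabla_{\boldsymbol{h}_2}\mathcal{L}$) with the same quantity accumulated from per-tile partials of $\nabla_{\boldsymbol{y}}\mathcal{L} \odot \boldsymbol{y}$, and also sketches the RMSNorm weight gradient as row-tile partials followed by a reduction. Function names and tile sizes are hypothetical; note that $d$ remains the RMSNorm width even when summing over the columns of $\boldsymbol{y}$.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def s_rmsnorm_side(dh2: Matrix, h2: Matrix) -> List[float]:
    """Row statistic as a standalone RMSNorm backward kernel would compute it."""
    d = len(h2[0])
    return [sum(g * v for g, v in zip(grow, hrow)) / d for grow, hrow in zip(dh2, h2)]


def s_gemm_boundary(dy: Matrix, y: Matrix, d: int, tile_n: int = 2) -> List[float]:
    """Same statistic from per-tile partials of dy * y, then a tiny reduction.

    d is the RMSNorm width (columns of h2), not the number of columns of y.
    """
    partials = [[sum(g * v for g, v in zip(grow[j0:j0 + tile_n], yrow[j0:j0 + tile_n]))
                 for j0 in range(0, len(yrow), tile_n)]
                for grow, yrow in zip(dy, y)]
    return [sum(p) / d for p in partials]


def dgamma_row_tiles(dh2: Matrix, h1: Matrix, r: List[float], tile_m: int = 2) -> List[float]:
    """RMSNorm weight gradient: per-row-tile column partials, then a reduction."""
    m, d = len(h1), len(h1[0])
    partials = [[sum(dh2[i][c] * r[i] * h1[i][c] for i in range(i0, min(i0 + tile_m, m)))
                 for c in range(d)]
                for i0 in range(0, m, tile_m)]
    return [sum(p[c] for p in partials) for c in range(d)]


eps = 1e-6
h1 = [[0.5, -1.0, 2.0, 0.3], [1.2, 0.4, -0.7, -1.5], [0.0, 2.2, 0.9, -0.3]]
gamma = [1.0, 0.5, 1.5, 2.0]
w1 = [[0.3, -0.2, 1.0], [0.7, 0.5, -0.4], [-1.1, 0.2, 0.6], [0.4, 0.9, 0.1]]
dy = [[1.0, -0.5, 0.2], [0.3, 0.8, -1.2], [-0.6, 0.1, 0.4]]
r = [1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps) for row in h1]
h2 = [[ri * v * g for v, g in zip(row, gamma)] for ri, row in zip(r, h1)]   # RMSNorm
y = matmul(h2, w1)
dh2 = matmul(dy, transpose(w1))
s_direct = s_rmsnorm_side(dh2, h2)
s_fused = s_gemm_boundary(dy, y, d=len(h1[0]))
dgamma = dgamma_row_tiles(dh2, h1, r)
```

The first comparison exercises the identity itself; the second shows that the feature-wise $\nabla\boldsymbol{\gamma}$ reduction, which runs across rows, can likewise be emitted as per-tile partials and finished by a small auxiliary pass.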

In consecutive Transformer patterns, each GEMM epilogue can therefore accumulate the row-wise statistic needed by the preceding RMSNorm backward. The RMSNorm weight gradient is handled similarly by emitting tile partials for the reduction across rows. Overall, RMSNorm backward becomes GEMM–epilogue kernels plus lightweight auxiliary reductions over tile partials. We give the full derivation and kernel organization in [Section A.2](https://arxiv.org/html/2605.19269#A1.SS2), with benchmarks in [Figure 11](https://arxiv.org/html/2605.19269#S4.F11).

### 3.3 Implementation

We implement CODA on top of CuTeDSL, which provides Python-level kernel authoring while retaining low-level control over details such as layouts and memory movement.

**Data movement.** Vector loads handle small broadcast operands such as RMSNorm weights. These values are staged once in shared memory and reused across subtiles. Tile loads handle larger operands such as residual activations and use Tensor Memory Accelerator (TMA) transfers between global memory and shared memory, allowing data movement to overlap with epilogue computation. Stores follow a similar tile-granular path for transformed outputs, saved intermediates, and reduction partials.

*Local computation.* Pairwise maps act on neighboring features, covering dimension-preserving operations such as RoPE, dimension-reducing operations such as SwiGLU, and dimension-expanding operations in the backward pass (for example, the SwiGLU backward, which returns two gradients per upstream value). When a map changes the feature dimension, the epilogue performs the corresponding local layout adjustment, such as compacting dimension-reducing outputs or packing pairs of 16-bit values for dimension-expanding outputs.
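To make the three kinds of pairwise maps concrete, here is a pure-Python sketch acting on one row of features. It assumes an interleaved layout in which each neighboring pair is `(x[2i], x[2i+1])`: for RoPE the pair is rotated by an angle, and for SwiGLU the pair is taken to be `(gate, up)`. These layout choices and the function names are illustrative assumptions, not CODA's actual memory layout or API.

```python
import math


def rope_pairwise(x, angles):
    """Dimension-preserving: rotate each pair (x[2i], x[2i+1]) by angles[i]."""
    out = []
    for i, theta in enumerate(angles):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out += [a * c - b * s, a * s + b * c]
    return out


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def swiglu_pairwise(x):
    """Dimension-reducing: each (gate, up) pair yields silu(gate) * up."""
    return [x[2 * i] * sigmoid(x[2 * i]) * x[2 * i + 1] for i in range(len(x) // 2)]


def swiglu_pairwise_backward(x, grad_out):
    """Dimension-expanding: each upstream gradient yields a (d_gate, d_up) pair."""
    grad_in = []
    for i, g in enumerate(grad_out):
        gate, up = x[2 * i], x[2 * i + 1]
        sig = sigmoid(gate)
        dsilu = sig * (1.0 + gate * (1.0 - sig))  # d/dgate of gate * sigmoid(gate)
        grad_in += [g * up * dsilu, g * gate * sig]
    return grad_in


row = [0.5, -1.0, 2.0, 0.25]
assert len(rope_pairwise(row, [0.1, 0.2])) == 4            # width preserved
assert len(swiglu_pairwise(row)) == 2                       # width halved
assert len(swiglu_pairwise_backward(row, [1.0, 1.0])) == 4  # width doubled
```

Each output depends only on one input pair (plus, for the backward, the saved forward pair), which is why these maps fit inside an epilogue without any cross-thread communication.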

*Reductions.* Tile-wise reductions follow the ownership of GEMM output fragments. Row-wise reductions are accumulated by the warp that owns the row. Column-wise reductions may span multiple warps, so each warp first produces a partial result, and these partials are combined through shared memory.
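The ownership rule can be mimicked in plain Python. In this sketch (an illustration, not CODA's code), the rows of a tile are split into contiguous blocks, one per hypothetical warp; on real hardware fragment ownership is more intricate. Row sums need no exchange because each row lives in one warp, whereas column sums are first computed per warp and then combined, standing in for the shared-memory step.

```python
def split_rows_among_warps(tile, num_warps):
    """Assign contiguous row blocks to warps (assumes rows divide evenly)."""
    rows_per_warp = len(tile) // num_warps
    return [tile[w * rows_per_warp:(w + 1) * rows_per_warp] for w in range(num_warps)]


def row_reduce(tile, num_warps):
    # Each warp reduces only the rows it owns; no cross-warp exchange is needed.
    out = []
    for warp_rows in split_rows_among_warps(tile, num_warps):
        out += [sum(row) for row in warp_rows]
    return out


def col_reduce(tile, num_warps):
    # Step 1: every warp produces a partial column sum over its own rows.
    partials = [[sum(col) for col in zip(*warp_rows)]
                for warp_rows in split_rows_among_warps(tile, num_warps)]
    # Step 2: partials are combined (the role played by shared memory on the GPU).
    return [sum(col) for col in zip(*partials)]


tile = [[r * 10 + c for c in range(4)] for r in range(8)]
print(row_reduce(tile, num_warps=4))  # [6, 46, 86, 126, 166, 206, 246, 286]
print(col_reduce(tile, num_warps=4))  # [280, 288, 296, 304]
```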

#### 3.3.1 LLM-Oriented Authoring

CODA is designed both for LLM workloads and for LLM-assisted authoring. Rather than asking a model to synthesize arbitrary CUDA or discover a hardware schedule from scratch, CODA exposes a constrained space of epilogue programs around expert-designed GEMM mainloops. Its primitives already encode efficient implementation strategies, so LLM-based authoring becomes a problem of composing vector loads, tile loads, pairwise maps, reductions, and stores for a given Transformer computation. This lightweight use of LLMs is complementary to prior work on kernel generation, which often relies on orchestration, search, execution feedback, or post-training \[[12](https://arxiv.org/html/2605.19269#bib.bib12), [16](https://arxiv.org/html/2605.19269#bib.bib16), [3](https://arxiv.org/html/2605.19269#bib.bib3), [10](https://arxiv.org/html/2605.19269#bib.bib10), [24](https://arxiv.org/html/2605.19269#bib.bib24)\].

Because CuTeDSL is relatively new, current models have limited exposure to its idioms\. We therefore provide curated demonstrations for each abstraction\. In practice, the repository itself acts as a growing demonstration set, with new kernels being written by adapting and composing existing examples\.

##### Compositions\.

Transformer kernels often combine several epilogue operations, such as residual addition and RMSNorm scaling. Monolithic fused epilogues lead to large, repetitive implementations that are difficult to fit into an LLM's context. CODA instead represents each epilogue as a composition of reusable primitives: the LLM specifies the local epilogue program, while the library supplies the fixed GEMM mainloop and the implementation pattern for each primitive. New fused kernels are therefore assembled from reusable building blocks instead of being rewritten from scratch.
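A toy model of this composition style, in plain Python: each primitive is a small function on an output tile (plus a dictionary of side outputs), and a fused epilogue is just their ordered composition applied to the GEMM result. The primitive names `residual_add`, `row_mean_square`, `row_vec_mul`, and `compose` are hypothetical and do not reflect the CODA API; the sketch only illustrates why composing small pieces keeps each kernel description short.

```python
from functools import reduce


def residual_add(residual):
    """Primitive: add a tile-loaded residual operand."""
    def apply(tile, side):
        return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(tile, residual)], side
    return apply


def row_mean_square(key):
    """Primitive: row-wise mean of squares, emitted as a side output; tile unchanged."""
    def apply(tile, side):
        return tile, {**side, key: [sum(v * v for v in row) / len(row) for row in tile]}
    return apply


def row_vec_mul(vec):
    """Primitive: multiply by a broadcast row vector (e.g., an RMSNorm weight)."""
    def apply(tile, side):
        return [[a * w for a, w in zip(row, vec)] for row in tile], side
    return apply


def compose(*primitives):
    """Chain primitives left to right into one epilogue program."""
    def epilogue(acc_tile):
        return reduce(lambda state, p: p(*state), primitives, (acc_tile, {}))
    return epilogue


acc = [[1.0, 2.0], [3.0, 4.0]]   # stand-in for a GEMM accumulator tile
res = [[0.5, 0.5], [-1.0, 0.0]]
gamma = [2.0, 0.5]

fused = compose(residual_add(res), row_mean_square("ms"), row_vec_mul(gamma))
out, side = fused(acc)
print(out)         # [[3.0, 1.25], [4.0, 2.0]]
print(side["ms"])  # [4.25, 10.0]
```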

![Figure 10](https://arxiv.org/html/2605.19269v2/x10.png)

Figure 10: Kernel-level speedups on reparameterized Transformer kernels relative to cuBLAS with torch.compile. Raw GEMM baselines using PyTorch/cuBLAS and QuACK are included as reference ceilings, since they execute only the matrix multiplication and no epilogue work.

## 4 Experiments

After the reparameterizations in [Section 3.2](https://arxiv.org/html/2605.19269#S3.SS2), we obtain a compact benchmark suite of GEMM-plus-epilogue kernels spanning the Transformer++ forward and backward pass ([Section C.1](https://arxiv.org/html/2605.19269#A3.SS1)). The suite covers nearly all computation outside attention, embeddings, auxiliary reductions, and lightweight glue operations. We evaluate two implementations. CODA (LLM) uses Claude Code to generate most kernels from a written specification, curated examples, and a running log of implementation tips, with lightweight human supervision. CODA (human) is written by human programmers using the same high-level reparameterizations, but without access to the exact CODA primitive set.

We compare against cuBLAS with torch.compile, as well as optimized LLM kernel libraries including Liger Kernel \[[6](https://arxiv.org/html/2605.19269#bib.bib6)\] and FlashInfer \[[23](https://arxiv.org/html/2605.19269#bib.bib23)\]. Because our reparameterized kernels do not always have one-to-one counterparts in existing libraries, we compose the closest available optimized primitives and fall back to PyTorch operators as needed. We apply torch.compile to each method when compatible. Additional setup details are given in [Section C.2](https://arxiv.org/html/2605.19269#A3.SS2).

##### Kernel Benchmarks\.

We first evaluate CODA at the individual-kernel level. Unless otherwise noted, we use square GEMM shapes with $M = N = K \in \{4096, 8192\}$. For cross-entropy kernels, we set the vocabulary dimension to 32768. For RoPE kernels, we use $N_{\mathrm{rope}} = 3N$ to account for QKV-style projections and use precomputed $\cos$ and $\sin$ tables.³ For kernels that emit partial reductions, we benchmark only the fused GEMM kernel with reduction tile size 128; auxiliary reductions are included in the block-level benchmarks below. We benchmark functions using Triton's do_bench and show means and standard deviations across 30 runs.

³ For CODA, we additionally pre-broadcast and extend these tables across batch, head, and QKV dimensions to avoid in-kernel branching, at the cost of additional input traffic.

We evaluate two groups of kernels from [Section C.1](https://arxiv.org/html/2605.19269#A3.SS1). The first group consists of standard Transformer-style kernels, such as GEMM with RoPE, SwiGLU, or cross-entropy epilogues. These kernels have close counterparts in existing libraries, so we compare against cuBLAS with torch.compile, Liger Kernel, and FlashInfer when applicable. Results are shown in [Figure 8](https://arxiv.org/html/2605.19269#S3.F8).

The second group consists of reparameterized Transformer forward and backward kernels, which generally do not have one-to-one equivalents in existing libraries. For these kernels, we primarily compare against cuBLAS with torch.compile, and additionally report raw GEMM from PyTorch/cuBLAS and QuACK.⁴ These raw GEMMs omit epilogue work and therefore serve as reference ceilings for the attainable throughput. Results are shown in [Figure 10](https://arxiv.org/html/2605.19269#S3.F10).

⁴ [https://github.com/Dao-AILab/quack](https://github.com/Dao-AILab/quack)

![Figure 11](https://arxiv.org/html/2605.19269v2/x11.png)

Figure 11: Block-level speedups for reparameterized Transformer kernel sequences, including auxiliary reductions and lightweight glue operations. Here, a layer denotes two consecutive GEMM-Residual-RMSNorm-GEMM blocks with the SwiGLU and RoPE epilogues, respectively.

##### Block Benchmarks\.

We next benchmark kernel sequences corresponding to Transformer sublayers and full layers, which we call *blocks*. We use hidden sizes in $\{2048, 4096, 8192\}$, roughly matching 1B, 7B, and 70B model scales, with an FFN expansion rate of $8/3$ rounded to multiples of 256 and a vocabulary size of 32768. Unlike isolated kernel benchmarks, these measurements include auxiliary reductions and lightweight glue operations.
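For reference, the FFN widths implied by this rule can be computed as follows. The text does not state the rounding direction; this sketch assumes rounding *up* to the next multiple of 256, as in the Llama reference implementation. Under that assumption the 4096 case gives 11008, the FFN width of LLaMA-7B.

```python
def ffn_hidden_size(hidden, multiple_of=256):
    """FFN width for an 8/3 expansion, rounded up to a multiple of `multiple_of` (assumed)."""
    raw = (8 * hidden) // 3  # integer part of 8/3 * hidden
    return multiple_of * ((raw + multiple_of - 1) // multiple_of)


for h in (2048, 4096, 8192):
    print(h, ffn_hidden_size(h))
# 2048 5632
# 4096 11008
# 8192 22016
```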

For the forward pass, we compare each reparameterized sequence against the closest available sequence of optimized operators. For the backward pass, the reparameterization changes the dependency structure: each sublayer emits partial statistics needed by the preceding RMSNorm backward. Individual backward sublayers therefore do not have direct PyTorch counterparts, so we report backward results at the full-layer level. In CODA, a layer consists of two consecutive GEMM-Residual-RMSNorm-GEMM blocks covering the SwiGLU and RoPE paths. Results are shown in [Figure 11](https://arxiv.org/html/2605.19269#S4.F11).

## 5 Conclusion and Limitations

CODA reparameterizes much of the computation in a Transformer as GEMM epilogues, reducing memory-bound overhead while preserving GEMM efficiency. Its constrained abstraction supports high-performance kernels authored by both humans and LLMs.

##### Limitations\.

Our reparameterizations target a common Transformer architecture; extending them to broader model families is future work. CODA currently focuses on single-GPU kernels and does not yet address distributed execution. Finally, while reparameterization improves efficiency, it can obscure module boundaries and algorithmic semantics, making integration with framework-level abstractions more challenging.

## Acknowledgment

We thank Beshr Islam Bouli, Kaiming Cheng, Xinle Cheng, Ryan Chin, Tarushii Goel, Wentao Guo, Lucas Torroba Hennigen, Alicia Li, Mayank Mishra, Jyothish Pari, Caiming Xiong, Nicholas Yap, Tianyuan Zhang, and Adam Zweiger for helpful discussion\. We gratefully acknowledge the support of the Schmidt Sciences AI2050 fellowship, the Google ML and Systems Junior Faculty Awards, the Google Research Scholar program, and the National Science Foundation \(Award \#2441872\)\.

## References

- Ansel et al. [2024] J. Ansel, E. Yang, H. He, N. Gimelshein, A. Jain, M. Voznesensky, B. Bao, P. Bell, D. Berard, E. Burovski, et al. PyTorch 2: Faster machine learning through dynamic Python bytecode transformation and graph compilation. In *Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2*, pages 929–947, 2024.
- Chen et al. [2018] T. Chen, T. Moreau, Z. Jiang, L. Zheng, E. Yan, H. Shen, M. Cowan, L. Wang, Y. Hu, L. Ceze, et al. TVM: An automated end-to-end optimizing compiler for deep learning. In *13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)*, pages 578–594, 2018.
- Chen et al. [2025] W. Chen, J. Zhu, Q. Fan, Y. Ma, and A. Zou. CUDA-LLM: LLMs can write efficient CUDA kernels. *arXiv preprint arXiv:2506.09092*, 2025.
- Chen et al. [2024] Z. Chen, A. Kerr, R. Cai, J. Kosaian, H. Wu, Y. Ding, and Y. Xie. EVT: Accelerating deep learning training with epilogue visitor tree. In *Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 3*, pages 301–316, 2024.
- Grattafiori et al. [2024] A. Grattafiori, A. Dubey, A. Jauhri, A. Pandey, A. Kadian, A. Al-Dahle, A. Letman, A. Mathur, A. Schelten, A. Vaughan, et al. The Llama 3 herd of models. *arXiv preprint arXiv:2407.21783*, 2024.
- Hsu et al. [2024] P.-L. Hsu, Y. Dai, V. Kothapalli, Q. Song, S. Tang, S. Zhu, S. Shimizu, S. Sahni, H. Ning, and Y. Chen. Liger Kernel: Efficient Triton kernels for LLM training. *arXiv preprint arXiv:2410.10989*, 2024.
- Ivanov et al. [2021] A. Ivanov, N. Dryden, T. Ben-Nun, S. Li, and T. Hoefler. Data movement is all you need: A case study on optimizing transformers. *Proceedings of Machine Learning and Systems*, 3:711–732, 2021.
- Jia et al. [2019] Z. Jia, O. Padon, J. Thomas, T. Warszawski, M. Zaharia, and A. Aiken. TASO: Optimizing deep learning computation with automatic generation of graph substitutions. In *Proceedings of the 27th ACM Symposium on Operating Systems Principles*, pages 47–62, 2019.
- Kwon et al. [2023] W. Kwon, Z. Li, S. Zhuang, Y. Sheng, L. Zheng, C. H. Yu, J. Gonzalez, H. Zhang, and I. Stoica. Efficient memory management for large language model serving with PagedAttention. In *Proceedings of the 29th Symposium on Operating Systems Principles*, pages 611–626, 2023.
- Lange et al. [2025] R. T. Lange, Q. Sun, A. Prasad, M. Faldor, Y. Tang, and D. Ha. Towards robust agentic CUDA kernel benchmarking, verification, and optimization. *arXiv preprint arXiv:2509.14279*, 2025.
- Liang et al. [2024] W. Liang, T. Liu, L. Wright, W. Constable, A. Gu, C.-C. Huang, I. Zhang, W. Feng, H. Huang, J. Wang, et al. TorchTitan: One-stop PyTorch native solution for production ready LLM pre-training. *arXiv preprint arXiv:2410.06511*, 2024.
- Ouyang et al. [2025] A. Ouyang, S. Guo, S. Arora, A. L. Zhang, W. Hu, C. Ré, and A. Mirhoseini. KernelBench: Can LLMs write efficient GPU kernels? *arXiv preprint arXiv:2502.10517*, 2025.
- Spector et al. [2025] B. Spector, J. Juravsky, S. Sul, O. Dugan, D. Lim, D. Fu, S. Arora, and C. Ré. Look ma, no bubbles! Designing a low-latency megakernel for Llama-1B, 2025.
- Spector et al. [2024] B. F. Spector, S. Arora, A. Singhal, D. Y. Fu, and C. Ré. ThunderKittens: Simple, fast, and adorable AI kernels. *arXiv preprint arXiv:2410.20399*, 2024.
- Su et al. [2024] J. Su, M. Ahmed, Y. Lu, S. Pan, W. Bo, and Y. Liu. RoFormer: Enhanced transformer with rotary position embedding. *Neurocomputing*, 568:127063, 2024.
- Su et al. [2025] S. Su, X. Sun, X. Li, A. Wang, J. Li, and C. Shum. CUDA-L2: Surpassing cuBLAS performance for matrix multiplication through reinforcement learning. *arXiv preprint arXiv:2512.02551*, 2025.
- Sul et al. [2025] S. H. Sul, S. Arora, B. F. Spector, and C. Ré. ParallelKittens: Systematic and practical simplification of multi-GPU AI kernels. *arXiv preprint arXiv:2511.13940*, 2025.
- Thakkar et al. [2023] V. Thakkar, P. Ramani, C. Cecka, A. Shivam, H. Lu, E. Yan, J. Kosaian, M. Hoemmen, H. Wu, A. Kerr, M. Nicely, D. Merrill, D. Blasig, A. Atluri, F. Qiao, P. Majcher, P. Springer, M. Hohnerbach, J. Wang, and M. Gupta. CUTLASS, Jan. 2023. URL [https://github.com/NVIDIA/cutlass](https://github.com/NVIDIA/cutlass).
- Tillet et al. [2019] P. Tillet, H.-T. Kung, and D. Cox. Triton: An intermediate language and compiler for tiled neural network computations. In *Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages*, pages 10–19, 2019.
- Wang et al. [2025] L. Wang, Y. Cheng, Y. Shi, Z. Tang, Z. Mo, W. Xie, L. Ma, Y. Xia, J. Xue, F. Yang, et al. TileLang: A composable tiled programming model for AI systems. *arXiv preprint arXiv:2504.17577*, 2025.
- Wijmans et al. [2025] E. Wijmans, B. Huval, A. Hertzberg, V. Koltun, and P. Krähenbühl. Cut your losses in large-vocabulary language models. In *International Conference on Learning Representations*, 2025.
- Wu et al. [2025] M. Wu, X. Cheng, S. Liu, C. Shi, J. Ji, M. K. Ao, P. Velliengiri, X. Miao, O. Padon, and Z. Jia. Mirage: A multi-level superoptimizer for tensor programs. In *19th USENIX Symposium on Operating Systems Design and Implementation (OSDI 25)*, pages 21–38, 2025.
- Ye et al. [2025] Z. Ye, L. Chen, R. Lai, W. Lin, Y. Zhang, S. Wang, T. Chen, B. Kasikci, V. Grover, A. Krishnamurthy, et al. FlashInfer: Efficient and customizable attention engine for LLM inference serving. *Proceedings of Machine Learning and Systems*, 7, 2025.
- Yuksekgonul et al. [2026] M. Yuksekgonul, D. Koceja, X. Li, F. Bianchi, J. McCaleb, X. Wang, J. Kautz, Y. Choi, J. Zou, C. Guestrin, et al. Learning to discover at test time. *arXiv preprint arXiv:2601.16175*, 2026.
- Zheng et al. [2024] L. Zheng, L. Yin, Z. Xie, C. L. Sun, J. Huang, C. H. Yu, S. Cao, C. Kozyrakis, I. Stoica, J. E. Gonzalez, et al. SGLang: Efficient execution of structured language model programs. *Advances in Neural Information Processing Systems*, 37:62557–62583, 2024.

## Appendix A Backward Pass

### A.1 Tile-wise Epilogue

Partition the GEMM output $\bm{h}$ into tiles $\bm{h}[i,j]$. A tile-wise epilogue applies an independent transformation to each tile:

$$\bm{h} = \bm{x}\bm{W}_0, \qquad \bm{h}' = \begin{bmatrix} \bm{f}[0,0]\left(\bm{h}[0,0]\right) & \cdots & \bm{f}[0,N]\left(\bm{h}[0,N]\right) \\ \vdots & \ddots & \vdots \\ \bm{f}[M,0]\left(\bm{h}[M,0]\right) & \cdots & \bm{f}[M,N]\left(\bm{h}[M,N]\right) \end{bmatrix}, \qquad \bm{y} = \bm{h}'\bm{W}_1.$$

The backward pass has the same block structure\.

$$\nabla_{\bm{h}'}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^\top, \qquad \nabla_{\bm{h}}\mathcal{L} = \begin{bmatrix} \bm{g}[0,0]\left(\nabla_{\bm{h}'}\mathcal{L}[0,0]\right) & \cdots & \bm{g}[0,N]\left(\nabla_{\bm{h}'}\mathcal{L}[0,N]\right) \\ \vdots & \ddots & \vdots \\ \bm{g}[M,0]\left(\nabla_{\bm{h}'}\mathcal{L}[M,0]\right) & \cdots & \bm{g}[M,N]\left(\nabla_{\bm{h}'}\mathcal{L}[M,N]\right) \end{bmatrix}, \qquad \nabla_{\bm{x}}\mathcal{L} = \nabla_{\bm{h}}\mathcal{L}\,\bm{W}_0^\top.$$

Here each $\bm{g}[i,j]$ is the local backward rule for the corresponding tile:

$$\bm{g}[i,j](\Delta) = \operatorname{unvec}\!\left(\bm{J}_{[i,j]}^\top \operatorname{vec}(\Delta)\right), \qquad \bm{J}_{[i,j]} = \frac{\partial \operatorname{vec}\left(\bm{f}[i,j]\left(\bm{h}[i,j]\right)\right)}{\partial \operatorname{vec}\left(\bm{h}[i,j]\right)}.$$

Thus, each gradient tile depends only on the corresponding forward tile and upstream gradient tile. No cross-tile communication is introduced, so the backward operation remains tile-local and can be implemented as a GEMM epilogue. As shown in [Figure 9](https://arxiv.org/html/2605.19269#S3.F9), the only structural change is the direction of fusion: forward epilogues are fused into the GEMM that produces their input, while backward epilogues are fused into the GEMM that produces the gradient with respect to their output.
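The local rule $\bm{g}[i,j](\Delta) = \operatorname{unvec}(\bm{J}_{[i,j]}^\top \operatorname{vec}(\Delta))$ can be checked on a single tile. This sketch uses elementwise SiLU as the example epilogue $\bm{f}$ (our choice for illustration), builds the tile Jacobian by central finite differences, and compares $\bm{J}^\top \operatorname{vec}(\Delta)$ with the closed-form SiLU backward. Everything it touches belongs to one tile.

```python
import math


def silu(z):
    return z / (1.0 + math.exp(-z))


def f_tile(tile):
    """Example tile-wise epilogue f: elementwise SiLU on a 2D tile."""
    return [[silu(v) for v in row] for row in tile]


def vec(tile):
    return [v for row in tile for v in row]


def unvec(flat, cols):
    return [flat[i:i + cols] for i in range(0, len(flat), cols)]


def jacobian(fn, tile, eps=1e-6):
    """J[a][b] = d vec(fn(tile))[a] / d vec(tile)[b], by central differences."""
    cols, x = len(tile[0]), vec(tile)
    J = [[0.0] * len(x) for _ in range(len(x))]
    for b in range(len(x)):
        xp, xm = x[:], x[:]
        xp[b] += eps
        xm[b] -= eps
        fp, fm = vec(fn(unvec(xp, cols))), vec(fn(unvec(xm, cols)))
        for a in range(len(x)):
            J[a][b] = (fp[a] - fm[a]) / (2 * eps)
    return J


def g_tile(tile, delta):
    """Local backward rule g(delta) = unvec(J^T vec(delta)) for this tile."""
    J, d = jacobian(f_tile, tile), vec(delta)
    n = len(d)
    return unvec([sum(J[a][b] * d[a] for a in range(n)) for b in range(n)], len(tile[0]))


def silu_backward(tile, delta):
    """Closed form: delta * sigmoid(h) * (1 + h * (1 - sigmoid(h)))."""
    out = []
    for hrow, drow in zip(tile, delta):
        row = []
        for h, dl in zip(hrow, drow):
            s = 1.0 / (1.0 + math.exp(-h))
            row.append(dl * s * (1.0 + h * (1.0 - s)))
        out.append(row)
    return out


h_tile = [[0.5, -1.0, 2.0], [0.0, 1.5, -0.3]]
delta = [[1.0, 0.2, -0.5], [0.7, -1.1, 0.4]]
numeric, exact = g_tile(h_tile, delta), silu_backward(h_tile, delta)
assert all(abs(a - b) < 1e-6 for ra, rb in zip(numeric, exact) for a, b in zip(ra, rb))
```

For an elementwise $\bm{f}$ the Jacobian is diagonal, so $\bm{J}^\top \operatorname{vec}(\Delta)$ collapses to an elementwise product; a pairwise map would give a block-diagonal Jacobian with the same tile-locality.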

### A.2 GEMM-RMSNorm-GEMM Backward Pass

We now describe the backward pass for the GEMM-epilogue-RMSNorm-GEMM pattern. RMSNorm is the first case where the backward pass is not purely tile-local. The reason is simple: RMSNorm contains a row-wise normalization factor, so its backward pass needs a row-wise statistic. In addition, the RMSNorm weight $\bm{\gamma}$ is shared across rows, so its gradient requires a reduction across the row dimension. The goal of this section is to show that these are the only non-local pieces. Everything else can still be fused into GEMM epilogues, with the non-local pieces handled by lightweight reductions over tile partials.

Consider the forward computation

$$\bm{h}_0 = \bm{x}\bm{W}_0, \qquad \bm{h}_1 = f(\bm{h}_0), \qquad \bm{h}_2 = \operatorname{RMSNorm}(\bm{h}_1, \bm{\gamma}), \qquad \bm{y} = \bm{h}_2\bm{W}_1.$$

Let $\bm{r}$ be the row-wise inverse RMS factor. We write $\overline{\bm{r}} = \bm{r}\mathbf{1}^\top$ and $\overline{\bm{\gamma}} = \mathbf{1}\bm{\gamma}^\top$ for the broadcasts of $\bm{r}$ and $\bm{\gamma}$ to the shape of $\bm{h}_1$. Then

$$\bm{h}_2 = \overline{\bm{r}} \odot \bm{h}_1 \odot \overline{\bm{\gamma}}.$$
Given the upstream gradient $\nabla_{\bm{y}}\mathcal{L}$, the first backward operation is the GEMM

$$\nabla_{\bm{h}_2}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^\top.$$

The RMSNorm backward can be written as

$$\begin{aligned} \nabla_{\bm{h}_1}\mathcal{L} &= \overline{\bm{r}} \odot \left(\nabla_{\bm{h}_2}\mathcal{L} \odot \overline{\bm{\gamma}} - \overline{\bm{r}} \odot \bm{h}_1 \odot \overline{\bm{s}}\right), \\ \nabla_{\bm{\gamma}}\mathcal{L} &= \operatorname{sum}_{\mathrm{rows}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \bm{h}_1 \odot \overline{\bm{r}}\right), \end{aligned}$$

where $\overline{\bm{s}} = \bm{s}\mathbf{1}^\top$ broadcasts one scalar per row. The row-wise statistic $\bm{s}$ is

$$\begin{aligned} \bm{s} &= \frac{1}{d}\,\bm{r} \odot \operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \overline{\bm{\gamma}} \odot \bm{h}_1\right) \\ &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \overline{\bm{r}} \odot \overline{\bm{\gamma}} \odot \bm{h}_1\right) \\ &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \bm{h}_2\right), \end{aligned}$$

where $d$ is the hidden dimension. This expression identifies the two non-local operations in RMSNorm backward. The statistic $\bm{s}$ is a reduction across columns, producing one scalar per row. The weight gradient $\nabla_{\bm{\gamma}}\mathcal{L}$ is a reduction across rows, producing one scalar per hidden feature.
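The closed-form backward above can be checked against finite differences on a tiny example. This sketch works on a single row, so the broadcasts $\overline{\bm{r}}$ and $\overline{\bm{s}}$ reduce to scalars. It assumes the usual definition $r = \left(\tfrac{1}{d}\sum_j h_{1,j}^2 + \epsilon\right)^{-1/2}$ with a small $\epsilon$ (the formulas above hold unchanged with $\epsilon$), and uses the test loss $\mathcal{L} = \sum_j G_j\, h_{2,j}$ so that $\nabla_{\bm{h}_2}\mathcal{L} = G$.

```python
import math

EPS = 1e-6  # assumed epsilon inside the RMS


def rmsnorm_row(h1, gamma):
    r = 1.0 / math.sqrt(sum(v * v for v in h1) / len(h1) + EPS)
    return [r * v * g for v, g in zip(h1, gamma)], r


def rmsnorm_row_backward(h1, gamma, grad_h2):
    d = len(h1)
    h2, r = rmsnorm_row(h1, gamma)
    s = sum(g * v for g, v in zip(grad_h2, h2)) / d  # s = (1/d) sum(grad_h2 * h2)
    grad_h1 = [r * (g * w - r * v * s) for g, w, v in zip(grad_h2, gamma, h1)]
    grad_gamma_row = [g * v * r for g, v in zip(grad_h2, h1)]  # this row's share of grad_gamma
    return grad_h1, grad_gamma_row


def loss(h1, gamma, G):
    h2, _ = rmsnorm_row(h1, gamma)
    return sum(a * b for a, b in zip(G, h2))


def numeric_grad(fn, x, eps=1e-6):
    out = []
    for j in range(len(x)):
        xp, xm = x[:], x[:]
        xp[j] += eps
        xm[j] -= eps
        out.append((fn(xp) - fn(xm)) / (2 * eps))
    return out


h1 = [0.8, -1.2, 0.3, 2.0]
gamma = [1.0, 0.5, -0.7, 1.3]
G = [0.2, -0.4, 1.0, 0.6]

grad_h1, grad_gamma = rmsnorm_row_backward(h1, gamma, G)
num_h1 = numeric_grad(lambda x: loss(x, gamma, G), h1)
num_gamma = numeric_grad(lambda w: loss(h1, w, G), gamma)
assert all(abs(a - b) < 1e-6 for a, b in zip(grad_h1, num_h1))
assert all(abs(a - b) < 1e-6 for a, b in zip(grad_gamma, num_gamma))
```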

A standalone RMSNorm backward kernel would compute these reductions by reading activation-sized tensors. The key observation is that the row-wise statistic $\bm{s}$ can be moved to a different boundary. Using $\nabla_{\bm{h}_2}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^\top$ and $\bm{y} = \bm{h}_2\bm{W}_1$, we have

$$\begin{aligned} \bm{s} &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \bm{h}_2\right) \\ &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left((\nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^\top) \odot \bm{h}_2\right) \\ &= \frac{1}{d}\operatorname{diag}\left((\nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^\top)\,\bm{h}_2^\top\right) \\ &= \frac{1}{d}\operatorname{diag}\left(\nabla_{\bm{y}}\mathcal{L}\,(\bm{h}_2\bm{W}_1)^\top\right) \\ &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{y}}\mathcal{L} \odot (\bm{h}_2\bm{W}_1)\right) \\ &= \frac{1}{d}\operatorname{sum}_{\mathrm{cols}}\left(\nabla_{\bm{y}}\mathcal{L} \odot \bm{y}\right). \end{aligned}$$

Intuitively, the RMSNorm backward needs the inner product between an activation and its gradient along each row. The identity above says that this inner product can be computed either before or after the following GEMM. This lets us compute the statistic at a boundary where $\bm{y}$ and $\nabla_{\bm{y}}\mathcal{L}$ are already available.

This is useful because Transformer layers contain consecutive GEMM-epilogue-RMSNorm-GEMM patterns. During the backward pass of one pattern, the GEMM that produces $\nabla_{\bm{x}}\mathcal{L}$ already has access to both $\bm{x}$ and $\nabla_{\bm{x}}\mathcal{L}$. Since this $\bm{x}$ is the output of the preceding pattern, the epilogue of the current pattern can accumulate the RMSNorm statistic needed by the preceding pattern:

$$\widehat{\bm{s}}_{\mathrm{prev}} = \operatorname{reduceTile}_{\mathrm{cols}}\left(\bm{x} \odot \nabla_{\bm{x}}\mathcal{L}\right), \qquad \bm{s}_{\mathrm{prev}} = \frac{1}{d}\operatorname{reduce}\left(\widehat{\bm{s}}_{\mathrm{prev}}\right).$$

Thus, each pattern computes the row-wise RMSNorm backward statistic required by the pattern before it. The reduction is still present, but it is now a small reduction over tile partials rather than a standalone activation-sized RMSNorm backward kernel.
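The two-stage reduction can be sketched in plain Python: the epilogue of each output tile emits one partial per row (the column reduction within that tile), and a small auxiliary pass sums the partials and applies $1/d$. The shapes and the ragged last tile below are illustrative; with the reduction tile size of 128 used in the kernel benchmarks and a 4096-wide output, each row would yield 32 partials. Here `d_prev` stands for the hidden size of the preceding RMSNorm, which need not equal the width of $\bm{x}$.

```python
import random


def epilogue_row_partials(x, grad_x, tile_n):
    """reduceTile_cols: per output tile, one partial sum of x * grad_x per row."""
    n = len(x[0])
    return [[sum(x[i][j] * grad_x[i][j] for j in range(t, min(t + tile_n, n)))
             for t in range(0, n, tile_n)]
            for i in range(len(x))]


def auxiliary_reduce(partials, d_prev):
    """reduce: sum the few partials of each row and scale by 1/d."""
    return [sum(row) / d_prev for row in partials]


random.seed(1)
M, N, TILE_N, D_PREV = 3, 10, 4, 8  # illustrative sizes; the last tile is ragged
x = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]
grad_x = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]

partials = epilogue_row_partials(x, grad_x, TILE_N)  # shape M x ceil(N / TILE_N)
s_prev = auxiliary_reduce(partials, D_PREV)

direct = [sum(a * b for a, b in zip(rx, rg)) / D_PREV for rx, rg in zip(x, grad_x)]
assert all(abs(a - b) < 1e-12 for a, b in zip(s_prev, direct))
```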

The RMSNorm weight gradient is handled similarly, except that its reduction is across rows rather than columns\. We accumulate tile partials in the RMSNorm backward epilogue:

$$\widehat{\nabla_{\bm{\gamma}}\mathcal{L}} = \operatorname{reduceTile}_{\mathrm{rows}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \bm{h}_1 \odot \overline{\bm{r}}\right), \qquad \nabla_{\bm{\gamma}}\mathcal{L} = \operatorname{reduce}_{\mathrm{rows}}\left(\widehat{\nabla_{\bm{\gamma}}\mathcal{L}}\right).$$
Putting these pieces together, the backward pass is organized as follows:

$$\begin{aligned}
\text{GEMM 1:}\quad & \nabla_{\bm{h}_2}\mathcal{L} = \nabla_{\bm{y}}\mathcal{L}\,\bm{W}_1^\top, \\
\text{Epilogue 1:}\quad & \nabla_{\bm{h}_1}\mathcal{L} = \overline{\bm{r}} \odot \left(\nabla_{\bm{h}_2}\mathcal{L} \odot \overline{\bm{\gamma}} - \bm{h}_1 \odot \overline{\bm{r}} \odot \overline{\bm{s}}\right), \\
& \nabla_{\bm{h}_0}\mathcal{L} = g\left(\nabla_{\bm{h}_1}\mathcal{L}\right), \\
& \widehat{\nabla_{\bm{\gamma}}\mathcal{L}} = \operatorname{reduceTile}_{\mathrm{rows}}\left(\nabla_{\bm{h}_2}\mathcal{L} \odot \bm{h}_1 \odot \overline{\bm{r}}\right), \\
\text{GEMM 2:}\quad & \nabla_{\bm{x}}\mathcal{L} = \nabla_{\bm{h}_0}\mathcal{L}\,\bm{W}_0^\top, \\
\text{Epilogue 2:}\quad & \widehat{\bm{s}}_{\mathrm{prev}} = \operatorname{reduceTile}_{\mathrm{cols}}\left(\bm{x} \odot \nabla_{\bm{x}}\mathcal{L}\right), \\
\text{Auxiliary reductions:}\quad & \bm{s}_{\mathrm{prev}} = \frac{1}{d}\operatorname{reduce}_{\mathrm{cols}}\left(\widehat{\bm{s}}_{\mathrm{prev}}\right), \\
& \nabla_{\bm{\gamma}}\mathcal{L} = \operatorname{reduce}_{\mathrm{rows}}\left(\widehat{\nabla_{\bm{\gamma}}\mathcal{L}}\right).
\end{aligned}$$

Here $g$ denotes the tile-local backward rule for the epilogue $f$. The statistic $\bm{s}$ used in Epilogue 1 is assumed to have already been accumulated by the following pattern in the backward order.
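Putting the organization together, the following pure-Python sketch runs one pattern end to end on tiny matrices and checks $\nabla_{\bm{x}}\mathcal{L}$ and $\nabla_{\bm{\gamma}}\mathcal{L}$ against central finite differences. Assumptions: the epilogue $f$ is SiLU, RMSNorm uses a small $\epsilon$, and $\bm{s}$ is computed from $\bm{y}$ and $\nabla_{\bm{y}}\mathcal{L}$ as the following pattern's epilogue would provide it. Epilogue 2's partials are omitted because this toy setup has no preceding RMSNorm to consume them.

```python
import math
import random

EPS = 1e-6  # RMSNorm epsilon (assumed)


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(x, W0, gamma, W1):
    h0 = matmul(x, W0)
    h1 = [[v * sigmoid(v) for v in row] for row in h0]  # epilogue f: SiLU (assumed)
    r = [1.0 / math.sqrt(sum(v * v for v in row) / len(row) + EPS) for row in h1]
    h2 = [[ri * v * g for v, g in zip(row, gamma)] for ri, row in zip(r, h1)]
    return h0, h1, r, matmul(h2, W1)


def backward(x, W0, gamma, W1, grad_y, tile_m=2):
    h0, h1, r, y = forward(x, W0, gamma, W1)
    d = len(gamma)
    # s as the following pattern would accumulate it: (1/d) sum_cols(grad_y * y).
    s = [sum(g * v for g, v in zip(gr, yr)) / d for gr, yr in zip(grad_y, y)]
    grad_h2 = matmul(grad_y, transpose(W1))  # GEMM 1
    # Epilogue 1: RMSNorm backward, then the local SiLU backward g.
    grad_h1 = [[ri * (g * w - v * ri * si) for g, w, v in zip(gr, gamma, hr)]
               for ri, si, gr, hr in zip(r, s, grad_h2, h1)]
    grad_h0 = [[g * sigmoid(v) * (1.0 + v * (1.0 - sigmoid(v))) for g, v in zip(gr, hr)]
               for gr, hr in zip(grad_h1, h0)]
    # Epilogue 1 also emits reduceTile_rows partials of grad_gamma over row tiles.
    contrib = [[g * v * ri for g, v in zip(gr, hr)] for ri, gr, hr in zip(r, grad_h2, h1)]
    partials = [[sum(col) for col in zip(*contrib[t:t + tile_m])]
                for t in range(0, len(contrib), tile_m)]
    grad_gamma = [sum(col) for col in zip(*partials)]  # auxiliary reduce_rows
    grad_x = matmul(grad_h0, transpose(W0))  # GEMM 2
    return grad_x, grad_gamma


def loss(x, W0, gamma, W1, grad_y):
    # L = sum(grad_y * y), so the upstream gradient of y is exactly grad_y.
    y = forward(x, W0, gamma, W1)[3]
    return sum(g * v for gr, yr in zip(grad_y, y) for g, v in zip(gr, yr))


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
M, K, D, N = 4, 3, 5, 2  # tokens, input width, RMSNorm width d, output width
x, W0, W1, grad_y = rand_matrix(M, K), rand_matrix(K, D), rand_matrix(D, N), rand_matrix(M, N)
gamma = [1.0 + 0.1 * random.gauss(0.0, 1.0) for _ in range(D)]
grad_x, grad_gamma = backward(x, W0, gamma, W1, grad_y)

step = 1e-6  # central finite differences for grad_x and grad_gamma
num_x = [[0.0] * K for _ in range(M)]
for i in range(M):
    for k in range(K):
        xp = [row[:] for row in x]
        xm = [row[:] for row in x]
        xp[i][k] += step
        xm[i][k] -= step
        num_x[i][k] = (loss(xp, W0, gamma, W1, grad_y) - loss(xm, W0, gamma, W1, grad_y)) / (2 * step)
num_gamma = []
for j in range(D):
    gp, gm = gamma[:], gamma[:]
    gp[j] += step
    gm[j] -= step
    num_gamma.append((loss(x, W0, gp, W1, grad_y) - loss(x, W0, gm, W1, grad_y)) / (2 * step))

err_x = max(abs(a - b) for ra, rb in zip(grad_x, num_x) for a, b in zip(ra, rb))
err_gamma = max(abs(a - b) for a, b in zip(grad_gamma, num_gamma))
assert err_x < 1e-5 and err_gamma < 1e-5
```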

Finally, the output of a pattern often passes through another epilogue before the next pattern begins:

$$\bm{h}_0 = \bm{x}\bm{W}_0, \qquad \bm{h}_1 = f_0(\bm{h}_0), \qquad \bm{h}_2 = \operatorname{RMSNorm}(\bm{h}_1, \bm{\gamma}), \qquad \bm{h}_3 = \bm{h}_2\bm{W}_1, \qquad \bm{y} = f_1(\bm{h}_3).$$

In this case, the statistic should be accumulated using the pre-epilogue tensor $\bm{h}_3$ and its gradient. The backward rule for $f_1$ is tile-local, so it can be fused before accumulating the statistic:

$$\begin{aligned}
\text{GEMM 2:}\quad & \nabla_{\bm{x}}\mathcal{L} = \nabla_{\bm{h}_0}\mathcal{L}\,\bm{W}_0^\top, \\
\text{Epilogue 2:}\quad & \nabla_{\overleftarrow{\bm{h}_3}}\mathcal{L} = \overleftarrow{f_1}\left(\nabla_{\bm{x}}\mathcal{L}\right), \\
& \widehat{\bm{s}}_{\mathrm{prev}} = \operatorname{reduceTile}_{\mathrm{cols}}\left(\overleftarrow{\bm{h}_3} \odot \nabla_{\overleftarrow{\bm{h}_3}}\mathcal{L}\right).
\end{aligned}$$

Here $\overleftarrow{\bm{h}_3}$ denotes the pre-epilogue GEMM output from the preceding pattern, and $\overleftarrow{f_1}$ denotes the local backward rule for its following epilogue. This case preserves the same structure: apply the local backward epilogue first, then accumulate the row-wise statistic from the pre-epilogue activation and its gradient.

Overall, this organization removes the activation\-sized RMSNorm backward kernel\. The remaining non\-local work consists only of reductions over tile partials: a column reduction for the row\-wise RMSNorm statistic, and a row reduction for the RMSNorm weight gradient\. The GEMMs and local backward updates are fused into GEMM epilogues, preserving the same GEMM\-plus\-epilogue structure used in the forward pass\.

## Appendix B CODA

### B.1 Epilogue Template

```python
# Epilogue kernel skeleton. Arguments elided as `...` in the original listing
# remain elided here.
epilogue.consumer_begin(...)
epilogue.producer_begin(...)

# Iterate over the epilogue subtiles of the GEMM output tile.
for epi_idx in range(num_epi_tiles):
    epilogue.consumer_begin_loop(...)

    # Producer hook: TMA loads of epilogue tile operands.
    epilogue.producer_tma_load(...)

    # Read the accumulator fragment and let the epilogue visit it.
    rD = load_accumulator_fragment(...)
    epilogue.consumer_visit(rD, ...)

    # Registers -> shared memory, plus the epilogue's own shared-memory stores.
    store_regs_to_smem(...)
    epilogue.consumer_smem_store(...)

    # Shared memory -> global memory via TMA, plus the epilogue's own TMA stores.
    tma_store_from_smem_to_gmem(...)
    epilogue.consumer_tma_store(...)

    epilogue.consumer_end_loop(gmem_coord)

epilogue.consumer_end(...)
```

Listing 1: Epilogue Kernel Abstraction.

### B.2 Epilogue Example

```python
import operator
from dataclasses import dataclass
from typing import Callable, NamedTuple

import torch
import cutlass
import cutlass.cute as cute
import cutlass.pipeline
import cutlass.utils

# Project-internal helpers (pytree, struct_utils, misc_utils, epilogue_utils, layout_utils,
# memory_utils, creation_utils, dtype_utils, get_dtype, torch2cute_dtype_map,
# BlockReductionOp, EpilogueVisitorTree, EpilogueSharedStorage, EVTList, EVTResidual,
# EVTColBlockReductionStore, HOPPER_WARP_REDUCTION_WIDTH) are defined elsewhere in the
# codebase; their imports are omitted from this listing.


def _create_mean_sq_reduction_op(element_type, inv_block_size):
    """Create a reduction op that accumulates the mean of squares: acc + val^2 * inv_block_size.

    The combine_fn squares each new element and scales it by 1/block_size before adding
    it to the accumulator. The warp-level reduction uses standard addition since partial
    sums are already accumulated and scaled.
    """
    init_value = element_type(0.0)
    inv_bs = element_type(inv_block_size)

    _sq_combine = lambda x, y: x + y * y * inv_bs
    _add_wrp = lambda tree_x, tree_y: pytree.tree_map(operator.add, tree_x, tree_y)

    return BlockReductionOp(
        combine_fn=lambda tree_x, tree_y: pytree.tree_map(_sq_combine, tree_x, tree_y),
        reduce_ssa=None,
        reduce_wrp=lambda xs: pytree.tree_map(
            lambda x: cute.arch.warp_reduction(
                x,
                op=_add_wrp,
                threads_in_group=HOPPER_WARP_REDUCTION_WIDTH,
            ),
            xs,
        ),
        init_value=init_value,
    )


class EVTRowVecMulPostAct(EpilogueVisitorTree):
    """
    Loads a per-N row vector W (cp.async to smem, then s2r), multiplies the
    accumulator by W into a separate register tile, and stores that scaled
    tile to a side output mPostAct via TMA. tRS_rD itself is left unchanged
    so the main D output (the unscaled GEMM result) is unaffected.

    This mirrors the rowvec=norm_weight side-output path of trainstation's
    `gemm_partial_rms_fwd`, kept local to this kernel rather than as a
    general-purpose rapier EVT.

    Inputs:
    - GEMM output (in registers): [M x N], unchanged by this op
    - mRowVec: [L, N] - RMSNorm weight, broadcast along M

    Outputs:
    - mPostAct: [M x N] = D * W (side output, written via TMA)
    """

    @struct_utils.mlir_namedtuple
    class EpilogueArguments(NamedTuple):
        mPostAct: cute.Tensor | None
        mRowVec: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueParams(EpilogueVisitorTree.EpilogueParams):
        mPostAct: cute.Tensor | None
        mRowVec: cute.Tensor | None
        epi_tma_atom: cute.CopyAtom
        epi_gmem_layout: cutlass.utils.LayoutEnum
        epi_smem_layout_staged: cute.Layout

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueTensorsSMem(EpilogueVisitorTree.EpilogueTensorsSMem):
        sPostAct: cute.Tensor | None
        sRowVec: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueTensors(EpilogueVisitorTree.EpilogueTensors):
        tDsPostAct: cute.Tensor
        tDgPostAct: cute.Tensor
        tRS_sPostAct: cute.Tensor
        epi_tma_atom: cute.CopyAtom
        tiled_copy_postact_r2s: cute.TiledCopy
        tDsRowVec: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpilogueTensorsLoop(EpilogueVisitorTree.EpilogueTensorsLoop):
        tDsPostAct: cute.Tensor
        tDgPostAct: cute.Tensor
        tRS_rPostAct: cute.Tensor | None
        tRS_sPostAct: cute.Tensor
        epi_tma_atom: cute.CopyAtom
        tiled_copy_postact_r2s: cute.TiledCopy
        tDrRowVec_epi: cute.Tensor | None

    @struct_utils.register_pytree_dataclass
    @dataclass
    class EpiloguePipelines(EpilogueVisitorTree.EpiloguePipelines):
        pass

    def __init__(
        self,
        acc_dtype: type[cute.Numeric],
        post_act_dtype: type[cute.Numeric],
        tile_shape_mnk: tuple[int, int, int],
        buffer_align_bytes: int,
    ) -> None:
        super().__init__()
        self.arch = 90
        self.acc_dtype = acc_dtype
        self.post_act_dtype = post_act_dtype
        self.container_dtype = post_act_dtype
        self.tile_shape_mnk = tile_shape_mnk
        self.buffer_align_bytes = buffer_align_bytes

    @cute.jit
    def to_underlying_arguments(
        self,
        epi_tile: cute.Tile,
        epi_stage: int,
        epi_load_stage: int,
        epi_args: EpilogueArguments,
    ) -> EpilogueParams:

        if cutlass.const_expr(epi_args.mPostAct is not None):
            mPostAct = misc_utils.static_assert_is_Tensor(epi_args.mPostAct)
            misc_utils.static_assert(get_dtype(mPostAct) is self.container_dtype)
            (
                epi_gmem_layout,
                epi_smem_layout_staged,
                epi_tma_atom,
                epi_tma_tensor,
            ) = epilogue_utils.prepare_tma(
                tma_op="s2g",
                epi_tile=epi_tile,
                epi_stage=epi_stage,
                epi_tensor=mPostAct,
            )

        if cutlass.const_expr(epi_args.mRowVec is not None):
            misc_utils.static_assert(epi_args.mPostAct is not None)
            mRowVec = misc_utils.static_assert_is_Tensor(epi_args.mRowVec)
            mRowVec = layout_utils.assumed_align_stride(
                mRowVec,
                assumed_align=4,
            )
        else:
            mRowVec = None

        return self.EpilogueParams(
            mPostAct=epi_tma_tensor,
            mRowVec=mRowVec,
            epi_tma_atom=epi_tma_atom,
            epi_gmem_layout=epi_gmem_layout,
            epi_smem_layout_staged=epi_smem_layout_staged,
        )

    @cute.jit
    def prefetch_tma_descriptors(
        self,
        epi_params: EpilogueParams,
    ) -> None:
        cute.nvgpu.cpasync.prefetch_descriptor(epi_params.epi_tma_atom)

    @cute.jit
    def consumer_begin(
        self,
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        tidx: cute.Int32,
        tiled_mma: cute.TiledMma,
        tRS_rD_layout: cute.Layout,
        epi_tile: cute.Tile,
        epi_num_threads: int,
        epi_num_matrices: int,
        epi_barrier: cutlass.pipeline.NamedBarrier,
        epi_params: EpilogueParams,
        epi_tensors_smem: EpilogueTensorsSMem,
    ) -> EpilogueTensors:

        tile_M = self.tile_shape_mnk[0]
        tile_N = self.tile_shape_mnk[1]
        m_idx, n_idx, _, batch_idx = tile_coord_mnkl
        thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)

        # Side output: register-to-smem copy, plus TMA partitions of the smem and gmem tiles.
        mPostAct = misc_utils.static_assert_is_Tensor(epi_params.mPostAct)
        sPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_smem.sPostAct)
        tiled_copy_postact_r2s, _, tRS_sPostAct = epilogue_utils.prepare_copy_r2s_sm90(
            tiled_copy_r2s=tiled_copy_r2s,
            tidx=tidx,
            dst=sPostAct,
            epi_layout=epi_params.epi_gmem_layout,
            epi_dtype=self.container_dtype,
            acc_dtype=self.acc_dtype,
        )
        gPostAct = mPostAct[None, None, batch_idx]
        gPostAct = cute.local_tile(gPostAct, (tile_M, tile_N), (m_idx, n_idx))
        gPostAct = cute.zipped_divide(gPostAct, epi_tile)

        tDsPostAct, tDgPostAct = cute.nvgpu.cpasync.tma_partition(
            atom=epi_params.epi_tma_atom,
            cta_coord=0,
            cta_layout=cute.make_layout(1),
            smem_tensor=cute.group_modes(sPostAct, 0, cute.rank(sPostAct) - 1),
            gmem_tensor=cute.group_modes(gPostAct, 0, cute.rank(gPostAct) - 1),
        )

        # Row vector: copy this tile's slice of W to smem, then build a stride-0
        # (broadcast along M) view partitioned like the accumulator.
        if cutlass.const_expr(epi_params.mRowVec is not None):
            mRowVec = misc_utils.static_assert_is_Tensor(epi_params.mRowVec)
            sRowVec = misc_utils.static_assert_is_Tensor(epi_tensors_smem.sRowVec)
            mRowVec = mRowVec[batch_idx, None]
            gRowVec = cute.local_tile(mRowVec, (tile_N,), (n_idx,))
            cRowVec = cute.make_identity_tensor(tile_N)
            limit_n = min(mRowVec.shape[0] - n_idx * tile_N, tile_N)
            memory_utils.g2s_copy_1d(
                src=gRowVec,
                dst=sRowVec,
                crd=cRowVec,
                shape=(limit_n,),
                num_threads=epi_num_threads,
                thread_index=tidx,
            )
            sRowVec_view_layout = cute.make_layout(
                shape=(tile_M, tile_N),
                stride=(0, 1),
            )
            sRowVec_view = cute.make_tensor(
                iterator=sRowVec.iterator,
                layout=sRowVec_view_layout,
            )
            tDsRowVec = thr_copy_r2s.partition_S(
                cute.flat_divide(sRowVec_view, epi_tile)
            )
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(0)
            epi_barrier.arrive_and_wait()
        else:
            tDsRowVec = None

        return self.EpilogueTensors(
            tDsPostAct=tDsPostAct,
            tDgPostAct=tDgPostAct,
            tRS_sPostAct=tRS_sPostAct,
            epi_tma_atom=epi_params.epi_tma_atom,
            tiled_copy_postact_r2s=tiled_copy_postact_r2s,
            tDsRowVec=tDsRowVec,
        )

    @cute.jit
    def consumer_end(
        self,
        tiled_copy_r2s: cute.TiledCopy,
        tile_coord_mnkl: cute.Coord,
        tidx: cute.Int32,
        shape_mnk: cute.Shape,
        epi_tile: cute.Tile,
        epi_num_threads: int,
        epi_barrier: cutlass.pipeline.NamedBarrier,
        epi_params: EpilogueParams,
        epi_tensors: EpilogueTensors,
        epi_tensors_smem: EpilogueTensorsSMem,
    ) -> None:
        pass

    @cute.jit
    def consumer_begin_loop(
        self,
        epi_coord: cute.Coord,
        epi_params: EpilogueParams,
        epi_tensors: EpilogueTensors,
        epi_pipelines: EpiloguePipelines,
    ) -> tuple[EpilogueTensorsLoop, EpiloguePipelines]:

        # Load this sub-tile's slice of W from smem into registers, converted to acc_dtype.
        if cutlass.const_expr(epi_tensors.tDsRowVec is not None):
            tDsRowVec = misc_utils.static_assert_is_Tensor(epi_tensors.tDsRowVec)
            tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))
            tDsRowVec_cur = tDsRowVec_cur[None, None, None, epi_coord]
            tDrRowVec_cvt = memory_utils.s2r_copy_1d(tDsRowVec_cur, dtype=self.acc_dtype)
        else:
            tDrRowVec_cvt = None

        return (
            self.EpilogueTensorsLoop(
                tDsPostAct=epi_tensors.tDsPostAct,
                tDgPostAct=epi_tensors.tDgPostAct,
                tRS_rPostAct=None,
                tRS_sPostAct=epi_tensors.tRS_sPostAct,
                epi_tma_atom=epi_tensors.epi_tma_atom,
                tiled_copy_postact_r2s=epi_tensors.tiled_copy_postact_r2s,
                tDrRowVec_epi=tDrRowVec_cvt,
            ),
            self.EpiloguePipelines(),
        )

    @cute.jit
    def consumer_visit(
        self,
        tRS_rD: cute.Tensor,
        shape_mnk: cute.Shape,
        epi_params: EpilogueParams,
        epi_tensors_loop: EpilogueTensorsLoop,
    ) -> EpilogueTensorsLoop:

        # Scale into a separate register tile so tRS_rD (the main D output) is untouched.
        tRS_rPostAct = creation_utils.allocate_tensor_like(
            tensor=tRS_rD,
            memspace="rmem",
            smem_allocator=None,
            dtype=self.acc_dtype,
        )
        if cutlass.const_expr(self.arch < 100):
            if cutlass.const_expr(epi_tensors_loop.tDrRowVec_epi is not None):
                tDrRowVec_epi = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tDrRowVec_epi)
                for i in cutlass.range_constexpr(cute.size(tRS_rPostAct)):
                    tRS_rPostAct[i] = tRS_rD[i] * tDrRowVec_epi[i]
            else:
                for i in cutlass.range_constexpr(cute.size(tRS_rPostAct)):
                    tRS_rPostAct[i] = tRS_rD[i]
        else:
            raise NotImplementedError

        tRS_rPostAct = dtype_utils.convert(
            tRS_rPostAct,
            dtype=self.post_act_dtype,
        )

        return self.EpilogueTensorsLoop(
            tDsPostAct=epi_tensors_loop.tDsPostAct,
            tDgPostAct=epi_tensors_loop.tDgPostAct,
            tRS_rPostAct=tRS_rPostAct,
            tRS_sPostAct=epi_tensors_loop.tRS_sPostAct,
            epi_tma_atom=epi_tensors_loop.epi_tma_atom,
            tiled_copy_postact_r2s=epi_tensors_loop.tiled_copy_postact_r2s,
            tDrRowVec_epi=epi_tensors_loop.tDrRowVec_epi,
        )

    @cute.jit
    def consumer_smem_store(
        self,
        epi_coord: cute.Coord,
        epi_buffer: cute.Int32,
        epi_params: EpilogueParams,
        epi_tensors_loop: EpilogueTensorsLoop,
    ) -> None:
        tiled_copy = epi_tensors_loop.tiled_copy_postact_r2s
        tRS_rPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tRS_rPostAct)
        tRS_sPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tRS_sPostAct)
        src = tiled_copy.retile(tRS_rPostAct)
        dst = tRS_sPostAct[None, None, None, epi_buffer]
        cute.copy(atom=tiled_copy, src=src, dst=dst)

    @cute.jit
    def consumer_tma_store(
        self,
        epi_coord: cute.Coord,
        epi_buffer: cute.Int32,
        epi_params: EpilogueParams,
        epi_tensors_loop: EpilogueTensorsLoop,
    ) -> None:
        atom = epi_tensors_loop.epi_tma_atom
        tDsPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tDsPostAct)
        tDgPostAct = misc_utils.static_assert_is_Tensor(epi_tensors_loop.tDgPostAct)
        src = tDsPostAct[None, epi_buffer]
        dst = tDgPostAct[None, epi_coord]
        cute.copy(atom=atom, src=src, dst=dst)

    @cute.jit
    def get_smem_struct(
        self,
        epi_load_stage: int,
        epi_num_threads: int,
        epi_params: EpilogueParams,
    ) -> type[EpilogueSharedStorage]:

        if cutlass.const_expr(epi_params.mPostAct is not None):
            post_act_smem_size = cute.cosize(epi_params.epi_smem_layout_staged)
        else:
            post_act_smem_size = 0

        if cutlass.const_expr(epi_params.mRowVec is not None):
            mRowVec = misc_utils.static_assert_is_Tensor(epi_params.mRowVec)
            row_vec_dtype = get_dtype(mRowVec)
            row_vec_smem_size = epilogue_utils.get_smem_size_vector(
                mTensor=mRowVec,
                epi_tile=self.tile_shape_mnk[1],
                epi_num_threads=epi_num_threads,
            )
        else:
            row_vec_dtype = cute.Float32
            row_vec_smem_size = 0

        @cute.struct
        class SharedStorage(EpilogueSharedStorage):
            sPostAct: cute.struct.Align[
                cute.struct.MemRange[self.container_dtype, post_act_smem_size],
                self.buffer_align_bytes,
            ]
            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]

        return SharedStorage

    @cute.jit
    def get_smem_tensors(
        self,
        storage: EpilogueSharedStorage,
        epi_num_threads: int,
        epi_params: EpilogueParams,
    ) -> EpilogueTensorsSMem:

        if cutlass.const_expr(epi_params.mPostAct is not None):
            sPostAct = storage.sPostAct.get_tensor(
                epi_params.epi_smem_layout_staged.outer,
                swizzle=epi_params.epi_smem_layout_staged.inner,
            )
        else:
            sPostAct = None

        if cutlass.const_expr(epi_params.mRowVec is not None):
            sRowVec_layout = cute.make_layout(self.tile_shape_mnk[1])
            sRowVec = storage.sRowVec.get_tensor(sRowVec_layout)
        else:
            sRowVec = None

        return self.EpilogueTensorsSMem(
            sPostAct=sPostAct,
            sRowVec=sRowVec,
        )

    @cute.jit
    def get_smem_bytes_per_stage(
        self,
        epi_tile: cute.Tile,
        epi_num_threads: int,
        epi_args: EpilogueArguments,
    ) -> tuple[int, int, int]:
        epi_smem_bytes_fixed = 0
        epi_smem_bytes_per_stage_cst = 0
        epi_smem_bytes_per_stage_pld = 0

        if cutlass.const_expr(epi_args.mPostAct is not None):
            mPostAct = misc_utils.static_assert_is_Tensor(epi_args.mPostAct)
            misc_utils.static_assert(get_dtype(mPostAct) is self.container_dtype)
            epi_smem_bytes_per_stage_cst = epi_smem_bytes_per_stage_cst + (
                epilogue_utils.get_epi_smem_bytes_per_stage_matrix(
                    mTensor=mPostAct,
                    epi_tile=epi_tile,
                )
            )

        if cutlass.const_expr(epi_args.mRowVec is not None):
            mRowVec = misc_utils.static_assert_is_Tensor(epi_args.mRowVec)
            epi_smem_bytes_fixed = epi_smem_bytes_fixed + (
                epilogue_utils.get_epi_smem_bytes_per_stage_fixed_vector(
                    mTensor=mRowVec,
                    epi_tile=self.tile_shape_mnk[1],
                    epi_num_threads=epi_num_threads,
                )
            )

        return (
            epi_smem_bytes_fixed,
            epi_smem_bytes_per_stage_cst,
            epi_smem_bytes_per_stage_pld,
        )


def prepare_epilogue(
    shape_mnkl: tuple[int, int, int, int],
    tile_shape_mn: tuple[int, int],
    C: torch.Tensor,
    S: torch.Tensor,
    W: torch.Tensor,
    O: torch.Tensor,
) -> tuple[
    Callable[..., EpilogueVisitorTree],
    EpilogueVisitorTree.EpilogueArguments,
    dict,
    tuple,
]:
    """Prepare the epilogue for a GEMM with residual, partial mean-of-squares, and fused
    per-N RMSNorm-weight scaling; mirrors trainstation's `gemm_partial_rms_fwd`.

    Composes three EVT visitors:
    1. EVTResidual: D = acc + C
    2. EVTColBlockReductionStore: S[m, nb] = mean(D[m, nb*bs:(nb+1)*bs]^2)
    3. EVTRowVecMulPostAct (local): O[m, n] = D[m, n] * W[n], side output via TMA

    The partial sum-of-squares is computed on the *unscaled* D, so a downstream
    rstd reduction sees the GEMM output before W is applied. tRS_rD is preserved
    so the main D output is also unscaled.

    Args:
        shape_mnkl: Problem shape (M, N, K, L), where L is the batch dimension.
        tile_shape_mn: CTA tile shape (tile_M, tile_N).
        C: Residual matrix of shape (M, N).
        S: Output for the partial mean-of-squares, of shape (M, num_blocks), in fp32.
        W: RMSNorm weight of shape (N,), broadcast across M.
        O: Output of shape (M, N) for D * W.

    Returns:
        Tuple of (epi_cls, epi_args, epi_outs, epi_keys).
    """
    M, N, K, L = shape_mnkl

    epi_dtype = torch2cute_dtype_map[C.dtype]
    post_act_dtype = torch2cute_dtype_map[O.dtype]

    epi_cls = lambda acc_dtype, tile_shape_mnk, buffer_align_bytes: EVTList([
        EVTResidual(
            acc_dtype=acc_dtype,
            epi_dtype=epi_dtype,
            tile_shape_mnk=tile_shape_mnk,
            buffer_align_bytes=buffer_align_bytes,
        ),
        EVTColBlockReductionStore(
            reduction_op=_create_mean_sq_reduction_op(
                element_type=acc_dtype,
                inv_block_size=1.0 / tile_shape_mnk[1],
            ),
            tile_shape_mnk=tile_shape_mnk,
        ),
        EVTRowVecMulPostAct(
            acc_dtype=acc_dtype,
            post_act_dtype=post_act_dtype,
            tile_shape_mnk=tile_shape_mnk,
            buffer_align_bytes=buffer_align_bytes,
        ),
    ])

    epi_args = EVTList.EpilogueArguments([
        EVTResidual.EpilogueArguments(
            mMatrix=C,
        ),
        EVTColBlockReductionStore.EpilogueArguments(
            mColVec=S,
        ),
        EVTRowVecMulPostAct.EpilogueArguments(
            mPostAct=O,
            mRowVec=W,
        ),
    ])

    epi_keys = (
        C.dtype,
        S.dtype,
        W.dtype,
        O.dtype,
        EVTResidual,
        EVTColBlockReductionStore,
        EVTRowVecMulPostAct,
    )

    epi_outs = {}

    return epi_cls, epi_args, epi_outs, epi_keys
```

Listing 2: Kernel Example.
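The reduction op at the top of Listing 2 folds each element into a per-thread accumulator with `combine_fn` (`acc + val^2 * inv_block_size`) and then merges per-thread partials with plain addition. The plain-Python sketch below checks that this two-level scheme yields the mean of squares over a block. The lane count of 32, the round-robin element-to-lane assignment, and the XOR-shuffle (butterfly) merge are assumptions standing in for `cute.arch.warp_reduction`, whose exact pattern is not shown in the listing; `sq_combine`, `butterfly_sum`, and `block_mean_sq` are hypothetical names.

```python
import math
import random


def sq_combine(acc, val, inv_bs):
    # Same update as `_sq_combine` in Listing 2: square, scale by 1/block_size, add.
    return acc + val * val * inv_bs


def butterfly_sum(partials):
    """All-reduce per-lane partials with an XOR-shuffle pattern (lane count must be a power of two)."""
    lanes = list(partials)
    offset = len(lanes) // 2
    while offset > 0:
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(len(lanes))]
        offset //= 2
    return lanes  # every lane now holds the full sum


def block_mean_sq(row_block, num_lanes=32):
    """Two-level reduction: per-lane folding with sq_combine, then an additive merge."""
    inv_bs = 1.0 / len(row_block)
    partials = [0.0] * num_lanes  # init_value = 0
    for j, val in enumerate(row_block):
        lane = j % num_lanes  # hypothetical element-to-lane assignment
        partials[lane] = sq_combine(partials[lane], val, inv_bs)
    return butterfly_sum(partials)[0]


random.seed(0)
block = [random.uniform(-2.0, 2.0) for _ in range(128)]
reference = math.fsum(v * v for v in block) / len(block)
result = block_mean_sq(block)
```

Because the division by the block size already happens inside `combine_fn`, the cross-thread merge can stay a plain sum, which is the design point the docstring in Listing 2 describes.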

## Appendix C Experiments

### C.1 List of Kernels

We summarize the kernels implemented in CODA. Each kernel is a GEMM followed by an epilogue program.

#### C.1.1 Basic Epilogue Kernels

We first list three basic GEMM\-plus\-epilogue kernels\. These are useful for isolating individual epilogue primitives, although they are not always used directly in the Transformer forward pass\.

Kernel 1: GEMM with RoPE. This kernel applies RoPE \[[15](https://arxiv.org/html/2605.19269#bib.bib15)\] to pairs of adjacent features in the GEMM output:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{O} &= \operatorname{RoPE}(\bm{D}).
\end{aligned}
$$
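A minimal NumPy reference for Kernel 1 can help when testing an epilogue implementation. It is a sketch under stated assumptions, not the CODA kernel: it assumes the adjacent-pair convention rotates features $(2i, 2i+1)$ of each row by the angle $p \cdot 10000^{-2i/d}$, where $p$ is the row's position, following the original RoPE formulation; the kernel's exact frequency schedule is not given here, and `rope_adjacent` and `gemm_rope` are hypothetical names. Because each pair is rotated independently, an output tile can apply RoPE entirely in registers as long as tile boundaries do not split a pair.

```python
import numpy as np


def rope_adjacent(D, positions, base=10000.0):
    """Rotate adjacent feature pairs (2i, 2i+1) of each row of D by position-dependent angles."""
    _, d = D.shape
    assert d % 2 == 0, "feature dimension must be even"
    inv_freq = base ** (-np.arange(0, d, 2) / d)      # base^(-2i/d) for pair i
    angles = positions[:, None] * inv_freq[None, :]   # shape (M, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x0, x1 = D[:, 0::2], D[:, 1::2]                   # even/odd features form one pair
    out = np.empty_like(D)
    out[:, 0::2] = x0 * cos - x1 * sin
    out[:, 1::2] = x0 * sin + x1 * cos
    return out


def gemm_rope(A, B, positions):
    # Kernel 1: D = A B, O = RoPE(D)
    return rope_adjacent(A @ B, positions)


rng = np.random.default_rng(0)
A, B = rng.standard_normal((6, 16)), rng.standard_normal((16, 8))
pos = np.arange(6, dtype=np.float64)
O = gemm_rope(A, B, pos)
```

Two sanity checks follow from the construction: each row's norm is unchanged by the rotation, and the dot product between a rotated query and a rotated key depends only on the difference of their positions.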

Kernel 2: GEMM with SwiGLU. This kernel applies a fused SwiGLU activation to an interleaved GEMM output, in which gate and up columns alternate:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
[\bm{G}, \bm{U}] &= \operatorname{interleavedSplit}(\bm{D}), \\
\bm{O} &= \operatorname{silu}(\bm{G}) \odot \bm{U}.
\end{aligned}
$$
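The interleaved layout is what lets SwiGLU run in the epilogue: if the gate and up projection weights are interleaved column by column, each gate column sits next to its matching up column, so any output tile with an even N-extent holds complete $(\bm{G}, \bm{U})$ pairs. The NumPy sketch below assumes even columns hold $\bm{G}$ and odd columns hold $\bm{U}$; `interleave_columns` and `gemm_swiglu` are hypothetical helper names.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def interleave_columns(W_gate, W_up):
    """Build a weight whose even columns come from W_gate and odd columns from W_up."""
    K, F = W_gate.shape
    W = np.empty((K, 2 * F), dtype=W_gate.dtype)
    W[:, 0::2] = W_gate
    W[:, 1::2] = W_up
    return W


def gemm_swiglu(A, B_interleaved):
    # Kernel 2: D = A B, [G, U] = interleavedSplit(D), O = silu(G) * U
    D = A @ B_interleaved
    G, U = D[:, 0::2], D[:, 1::2]
    return silu(G) * U


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))
W_gate, W_up = rng.standard_normal((8, 6)), rng.standard_normal((8, 6))
B = interleave_columns(W_gate, W_up)
O = gemm_swiglu(A, B)

# Processing 4-column output tiles independently gives the same result,
# because every tile contains whole (G, U) pairs.
O_tiled = np.concatenate([gemm_swiglu(A, B[:, s:s + 4]) for s in range(0, 12, 4)], axis=1)
```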

Kernel 3: GEMM with partial cross-entropy. This kernel computes logits, selects the target logit, and emits block-wise log-sum-exp statistics for the cross-entropy loss:

$$
\begin{aligned}
\bm{Z} &= \bm{A}\bm{B}, \\
\bm{z}_{\mathrm{tgt}} &= \bm{Z}[\bm{y}], \\
\widehat{\bm{l}}_{\mathrm{lse}} &= \operatorname{reduceTile}_{\log\sum\exp}(\bm{Z}).
\end{aligned}
$$
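Each tile sees only a slice of the vocabulary, so it can emit only a block-local log-sum-exp; a cheap follow-up step merges the per-block values into the full-row log-sum-exp, and the per-row loss is $\mathrm{lse} - z_{\mathrm{tgt}}$. The NumPy sketch below shows one standard, numerically stable way to do the merge. The block sizes and the names `block_lse`, `merge_lse`, and `cross_entropy_from_partials` are illustrative assumptions, not the paper's kernel.

```python
import numpy as np


def block_lse(Z, block_n):
    """Per-row log-sum-exp over each column block of Z (what each tile would emit)."""
    _, V = Z.shape
    out = []
    for start in range(0, V, block_n):
        blk = Z[:, start:start + block_n]
        m = blk.max(axis=1, keepdims=True)  # subtract the block max for stability
        out.append((m + np.log(np.exp(blk - m).sum(axis=1, keepdims=True)))[:, 0])
    return np.stack(out, axis=1)  # shape (M, num_blocks)


def merge_lse(partial):
    """Combine block-wise log-sum-exp values into the full-row log-sum-exp."""
    m = partial.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(partial - m).sum(axis=1, keepdims=True)))[:, 0]


def cross_entropy_from_partials(A, B, y, block_n):
    Z = A @ B                                 # logits
    z_tgt = Z[np.arange(Z.shape[0]), y]       # target-logit selection
    lse = merge_lse(block_lse(Z, block_n))    # reduce the tile-level statistics
    return lse - z_tgt                        # per-row cross-entropy loss


rng = np.random.default_rng(0)
A, B = rng.standard_normal((5, 16)), rng.standard_normal((16, 40))
y = rng.integers(0, 40, size=5)
loss = cross_entropy_from_partials(A, B, y, block_n=8)
```

The merge is exact for any block size, including a short last block, since log-sum-exp over a union of disjoint blocks equals log-sum-exp of the per-block log-sum-exps.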

#### C.1.2 Forward-Pass Kernels

The following kernels implement the reparameterized Transformer forward pass\. They compose the basic epilogue primitives with RMSNorm scaling, residual updates, and partial reductions\.

Kernel 4: GEMM with residual, partial RMSNorm, and weight scaling. This kernel implements the first stage of the GEMM–Residual–RMSNorm–GEMM pattern. It forms the residual-updated activation, emits partial RMS statistics, and applies the RMSNorm weight:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B} + \bm{C}, \\
\widehat{\bm{r}} &= \operatorname{reduceTile}_{\mathrm{cols}}(\bm{D} \odot \bm{D}), \\
\bm{O} &= \bm{D} \odot \bm{\gamma}.
\end{aligned}
$$
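The RMS statistic cannot be finished inside one tile because it needs the whole row, so Kernel 4 emits one partial per (row, column block); Listing 2 stores it as `S[m, nb] = mean(D[m, nb*bs:(nb+1)*bs]^2)`. The NumPy sketch below reproduces the three tile outputs and the follow-up reduction to the row-wise inverse RMS factor $\bm{r}$. The `eps` value and the helper names are assumptions; the sketch scales every block by `1/block_n`, as the reduction op in Listing 2 does, which keeps the final sum correct even when the last block is short.

```python
import numpy as np


def gemm_residual_partial_rms(A, B, C, gamma, block_n):
    """Kernel 4 outputs: D = A B + C, per-block mean of squares S, and O = D * gamma."""
    D = A @ B + C
    M, N = D.shape
    num_blocks = -(-N // block_n)  # ceiling division
    S = np.zeros((M, num_blocks))
    for nb in range(num_blocks):
        blk = D[:, nb * block_n:(nb + 1) * block_n]
        # Scale by 1/block_n (not by a shorter tail width), matching inv_block_size.
        S[:, nb] = (blk * blk).sum(axis=1) / block_n
    O = D * gamma  # RMSNorm weight applied in the epilogue
    return D, S, O


def finish_rstd(S, N, block_n, eps=1e-6):
    """Reduce the partials to the row-wise inverse RMS factor r."""
    mean_sq = S.sum(axis=1) * block_n / N
    return 1.0 / np.sqrt(mean_sq + eps)


rng = np.random.default_rng(0)
A, B = rng.standard_normal((4, 8)), rng.standard_normal((8, 20))
C, gamma = rng.standard_normal((4, 20)), rng.standard_normal(20)
D, S, O = gemm_residual_partial_rms(A, B, C, gamma, block_n=8)
r = finish_rstd(S, N=20, block_n=8)
```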

Kernel 5: GEMM with RMSNorm scaling. This kernel consumes a precomputed row-wise normalization factor $\bm{r}$ and applies it in the GEMM epilogue:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{O} &= \bm{D} \odot \bm{r}.
\end{aligned}
$$
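Kernel 5 rests on a reparameterization that moves RMSNorm past the next GEMM. RMSNorm scales row $m$ by a scalar $r_m$ and column $j$ by $\gamma_j$, and a per-row scale commutes with right-multiplication by a weight matrix, so $\operatorname{RMSNorm}(\bm{X})\bm{W} = \big((\bm{X} \odot \bm{\gamma})\bm{W}\big) \odot \bm{r}$. The GEMM can therefore consume the $\bm{\gamma}$-scaled activation (such as the output $\bm{O}$ of Kernel 4) and apply $\bm{r}$ in its epilogue. The NumPy check below assumes the common form $r_m = 1/\sqrt{\operatorname{mean}_j X_{mj}^2 + \epsilon}$ with a hypothetical $\epsilon = 10^{-6}$.

```python
import numpy as np


def rmsnorm(X, gamma, eps=1e-6):
    r = 1.0 / np.sqrt((X * X).mean(axis=1) + eps)  # row-wise inverse RMS
    return X * r[:, None] * gamma, r


rng = np.random.default_rng(0)
X, gamma = rng.standard_normal((4, 16)), rng.standard_normal(16)
W = rng.standard_normal((16, 12))

Y, r = rmsnorm(X, gamma)
reference = Y @ W                           # normalize first, then GEMM
fused = ((X * gamma) @ W) * r[:, None]      # Kernel 5: GEMM on X * gamma, scale rows by r in the epilogue
```

Only $\bm{r}$ depends on a full-row reduction, so deferring it to the epilogue is what allows the preceding kernel to stream its output without waiting for the row statistic.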

Kernel 6: GEMM with RMSNorm and SwiGLU. This kernel composes row-wise RMSNorm scaling with SwiGLU, corresponding to the MLP gate/up projection:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{D}' &= \bm{D} \odot \bm{r}, \\
[\bm{G}, \bm{U}] &= \operatorname{interleavedSplit}(\bm{D}'), \\
\bm{O} &= \operatorname{silu}(\bm{G}) \odot \bm{U}.
\end{aligned}
$$
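Unlike the row scale $\bm{r}$, SwiGLU is nonlinear, so $\bm{r}$ has to be applied to $\bm{D}$ before the split and activation, which is the order Kernel 6 uses. The NumPy sketch below checks the composed epilogue against the unfused RMSNorm, gate/up projection, and SwiGLU path, and shows that moving $\bm{r}$ after the activation would change the result. It assumes even columns of the interleaved output hold $\bm{G}$ and odd columns hold $\bm{U}$, and uses a hypothetical $\epsilon = 10^{-6}$.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
X, gamma = rng.standard_normal((4, 16)), rng.standard_normal(16)
W_gate, W_up = rng.standard_normal((16, 6)), rng.standard_normal((16, 6))
W = np.empty((16, 12))
W[:, 0::2], W[:, 1::2] = W_gate, W_up        # interleaved gate/up weight

r = 1.0 / np.sqrt((X * X).mean(axis=1) + 1e-6)

# Unfused reference: RMSNorm, two projections, SwiGLU.
Xn = X * r[:, None] * gamma
reference = silu(Xn @ W_gate) * (Xn @ W_up)

# Kernel 6: D = (X * gamma) W, D' = D * r, interleaved split, SwiGLU.
D = (X * gamma) @ W
D_prime = D * r[:, None]
fused = silu(D_prime[:, 0::2]) * D_prime[:, 1::2]

# Applying r after the activation instead is not equivalent.
wrong_order = silu(D[:, 0::2]) * D[:, 1::2] * r[:, None]
```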

Kernel 7: GEMM with RMSNorm and RoPE. This kernel composes row-wise RMSNorm scaling with RoPE, corresponding to the QKV projection followed by rotary positional embedding:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}, \\
\bm{D}' &= \bm{D} \odot \bm{r}, \\
\bm{O} &= \operatorname{RoPE}(\bm{D}').
\end{aligned}
$$
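A NumPy check that Kernel 7 reproduces "RMSNorm, then QKV projection, then RoPE". It assumes the adjacent-pair RoPE convention with angles $p \cdot 10000^{-2i/d}$ (the kernel's frequency schedule is not given here), a hypothetical $\epsilon = 10^{-6}$, and, for simplicity, rotates all output columns, whereas a real QKV layout would rotate only the query and key heads. Because RoPE applies a linear map to each row and $r_m$ is a per-row scalar, the two commute, so the order of the scale and the rotation inside the epilogue does not affect the result.

```python
import numpy as np


def rope_adjacent(D, positions, base=10000.0):
    """Rotate feature pairs (2i, 2i+1) of each row by positions * base**(-2i/d)."""
    d = D.shape[1]
    angles = positions[:, None] * base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(D)
    out[:, 0::2] = D[:, 0::2] * cos - D[:, 1::2] * sin
    out[:, 1::2] = D[:, 0::2] * sin + D[:, 1::2] * cos
    return out


rng = np.random.default_rng(0)
X, gamma = rng.standard_normal((5, 16)), rng.standard_normal(16)
W_qkv = rng.standard_normal((16, 24))
pos = np.arange(5, dtype=np.float64)
r = 1.0 / np.sqrt((X * X).mean(axis=1) + 1e-6)

reference = rope_adjacent((X * r[:, None] * gamma) @ W_qkv, pos)    # unfused path
fused = rope_adjacent(((X * gamma) @ W_qkv) * r[:, None], pos)      # Kernel 7
scale_after = rope_adjacent((X * gamma) @ W_qkv, pos) * r[:, None]  # r commutes with RoPE
```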

Kernel 8: GEMM with RMSNorm and partial cross-entropy. This kernel adds row-wise RMSNorm scaling before target-logit selection and partial log-sum-exp reduction, corresponding to the language-modeling head:

$$
\begin{aligned}
\bm{Z} &= (\bm{A}\bm{B}) \odot \bm{r}, \\
\bm{z}_{\mathrm{tgt}} &= \bm{Z}[\bm{y}], \\
\widehat{\bm{l}}_{\mathrm{lse}} &= \operatorname{reduceTile}_{\log\sum\exp}(\bm{Z}).
\end{aligned}
$$
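An end-to-end NumPy check of Kernel 8, under stated assumptions: the GEMM input $\bm{A}$ is taken to be the $\bm{\gamma}$-scaled hidden state (as the output $\bm{O}$ of Kernel 4 would provide), $\epsilon = 10^{-6}$, and the vocabulary is split into hypothetical 16-column tiles. Merging the per-tile log-sum-exp values and subtracting the target logit reproduces the cross-entropy of the unfused RMSNorm, LM head, and log-softmax path.

```python
import numpy as np


def logsumexp_rows(X):
    m = X.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(X - m).sum(axis=1, keepdims=True)))[:, 0]


rng = np.random.default_rng(0)
H, gamma = rng.standard_normal((6, 16)), rng.standard_normal(16)
W_head = rng.standard_normal((16, 50))
y = rng.integers(0, 50, size=6)
r = 1.0 / np.sqrt((H * H).mean(axis=1) + 1e-6)

# Kernel 8: Z = ((H * gamma) W) * r, then target selection and per-tile log-sum-exp.
Z = ((H * gamma) @ W_head) * r[:, None]
block_n = 16
partials = np.stack(
    [logsumexp_rows(Z[:, s:s + block_n]) for s in range(0, Z.shape[1], block_n)], axis=1
)
loss = logsumexp_rows(partials) - Z[np.arange(6), y]

# Unfused reference: RMSNorm, LM head, log-softmax at the target.
logits = (H * r[:, None] * gamma) @ W_head
log_probs = logits - logsumexp_rows(logits)[:, None]
reference = -log_probs[np.arange(6), y]
```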

#### C.1.3 Backward-Pass Kernels

Finally, we list the backward kernels\. These kernels mirror the forward structure: each performs a GEMM, applies the local backward rule in the epilogue, and emits partial reductions needed by neighboring RMSNorm backward computations\.

Kernel 9: GEMM with residual and RMSNorm backward. This kernel implements the local part of RMSNorm backward. Let $\bm{C}$ denote the RMSNorm input, $\bm{r}$ the row-wise inverse RMS factor, $\bm{\gamma}$ the RMSNorm weight, and $\bm{z}_{\Delta z}$ the row-wise RMSNorm backward statistic:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}^{\top}, \\
\bm{C}_{\mathrm{norm}} &= \bm{C} \odot \bm{r}, \\
\bm{O}_{\mathrm{out}} &= \bm{O}_{\mathrm{in}} + \left(\bm{D} \odot \bm{\gamma} - \bm{C}_{\mathrm{norm}} \odot \bm{z}_{\Delta z}\right) \odot \bm{r}, \\
\bm{C}_{\mathrm{out}} &= \bm{C}_{\mathrm{norm}} \odot \bm{\gamma}, \\
\widehat{\nabla_{\bm{\gamma}}\mathcal{L}} &= \operatorname{reduceTile}_{\mathrm{rows}}\left(\bm{D} \odot \bm{C}_{\mathrm{norm}}\right).
\end{aligned}
$$
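To sanity-check the Kernel 9 formulas, the NumPy sketch below compares them with central finite differences of RMSNorm. Two points are our assumptions: that $\bm{D}$ holds the upstream gradient with respect to the RMSNorm output, and that the statistic is $\bm{z}_{\Delta z} = \operatorname{mean}_{\mathrm{cols}}(\bm{D} \odot \bm{\gamma} \odot \bm{C}_{\mathrm{norm}})$, the standard RMSNorm backward term (the text above only calls it "the row-wise RMSNorm backward statistic"). We also take $\bm{O}_{\mathrm{in}} = \bm{0}$ so that $\bm{O}_{\mathrm{out}}$ is just the input gradient, and use a hypothetical $\epsilon = 10^{-6}$; `kernel9_local` and `numerical_grad` are illustrative names.

```python
import numpy as np

EPS = 1e-6


def rmsnorm_loss(C, gamma, D):
    """Scalar loss sum(D * RMSNorm(C)); its gradient w.r.t. the RMSNorm output is D."""
    r = 1.0 / np.sqrt((C * C).mean(axis=1, keepdims=True) + EPS)
    return float((D * C * r * gamma).sum())


def kernel9_local(C, gamma, D, O_in):
    r = 1.0 / np.sqrt((C * C).mean(axis=1, keepdims=True) + EPS)  # shape (M, 1)
    C_norm = C * r
    z = (D * gamma * C_norm).mean(axis=1, keepdims=True)  # assumed definition of z_dz
    O_out = O_in + (D * gamma - C_norm * z) * r
    C_out = C_norm * gamma
    dgamma_partial = (D * C_norm).sum(axis=0)              # reduceTile over rows
    return O_out, C_out, dgamma_partial


def numerical_grad(f, X, h=1e-6):
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += h
        Xm[idx] -= h
        G[idx] = (f(Xp) - f(Xm)) / (2 * h)
    return G


rng = np.random.default_rng(0)
C, gamma, D = rng.standard_normal((3, 8)), rng.standard_normal(8), rng.standard_normal((3, 8))
O_out, C_out, dgamma = kernel9_local(C, gamma, D, O_in=np.zeros_like(C))
dC_num = numerical_grad(lambda X: rmsnorm_loss(X, gamma, D), C)
dgamma_num = numerical_grad(lambda g: rmsnorm_loss(C, g, D), gamma)
```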

Kernel 10: GEMM with SwiGLU backward. This kernel computes the backward pass of a fused SwiGLU epilogue and emits the row-wise statistic needed by the preceding RMSNorm backward. Let $\bm{Z}$ be the saved interleaved pre-activation tensor:

$$
\begin{aligned}
\bm{D} &= \bm{A}\bm{B}^{\top}, \\
[\bm{G}, \bm{U}] &= \operatorname{interleavedSplit}(\bm{Z}), \\
\bm{O} &= \operatorname{silu}(\bm{G}) \odot \bm{U}, \\
\nabla_{\bm{U}}\mathcal{L} &= \bm{D} \odot \operatorname{silu}(\bm{G}), \\
\nabla_{\bm{G}}\mathcal{L} &= \bm{D} \odot \bm{U} \odot \left(\sigma(\bm{G}) + \operatorname{silu}(\bm{G}) \odot (1 - \sigma(\bm{G}))\right), \\
\nabla_{\bm{Z}}\mathcal{L} &= \operatorname{interleavedConcat}\left(\nabla_{\bm{G}}\mathcal{L}, \nabla_{\bm{U}}\mathcal{L}\right), \\
\widehat{\bm{z}_{\Delta z}} &= \operatorname{reduceTile}_{\mathrm{cols}}\left(\bm{G} \odot \nabla_{\bm{G}}\mathcal{L} + \bm{U} \odot \nabla_{\bm{U}}\mathcal{L}\right).
\end{aligned}
$$
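A finite-difference check of the Kernel 10 gradient formulas, in NumPy. It assumes that even columns of $\bm{Z}$ hold $\bm{G}$ and odd columns hold $\bm{U}$, and treats $\bm{D}$ as the upstream gradient of $\bm{O}$; `kernel10_local` and `swiglu` are hypothetical names. Since interleavedConcat only permutes columns, the emitted statistic equals the row sum of $\bm{Z} \odot \nabla_{\bm{Z}}\mathcal{L}$, which the sketch also computes.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def silu(x):
    return x * sigmoid(x)


def swiglu(Z):
    # Forward: even columns are G, odd columns are U (assumed interleaving).
    return silu(Z[:, 0::2]) * Z[:, 1::2]


def kernel10_local(Z, D):
    G, U = Z[:, 0::2], Z[:, 1::2]
    O = silu(G) * U                                  # recomputed forward output
    dU = D * silu(G)
    dG = D * U * (sigmoid(G) + silu(G) * (1.0 - sigmoid(G)))
    dZ = np.empty_like(Z)
    dZ[:, 0::2], dZ[:, 1::2] = dG, dU                # interleavedConcat
    z_stat = (G * dG + U * dU).sum(axis=1)           # reduceTile over columns
    return O, dZ, z_stat


rng = np.random.default_rng(0)
Z, D = rng.standard_normal((3, 8)), rng.standard_normal((3, 4))
O, dZ, z_stat = kernel10_local(Z, D)

# Central finite differences of the scalar loss sum(D * swiglu(Z)).
h = 1e-6
dZ_num = np.zeros_like(Z)
for idx in np.ndindex(*Z.shape):
    Zp, Zm = Z.copy(), Z.copy()
    Zp[idx] += h
    Zm[idx] -= h
    dZ_num[idx] = ((D * swiglu(Zp)).sum() - (D * swiglu(Zm)).sum()) / (2 * h)
```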

### C.2 Setup Details

Experiments are conducted using a single H100 GPU\. We use the following package versions\.

1. PyTorch 2.10.0
2. CuTeDSL 4.4.2
3. Liger Kernels 0.8.0
4. FlashInfer 0.6.10.post1
5. QuACK Kernels 0.4.1
