Learned Relay Representations for Forward-Thinking Discrete Diffusion Models

arXiv cs.LG Papers

Summary

This paper introduces Learned Relay Representations (Relay), a method that allows masked diffusion models to propagate latent information across denoising steps, overcoming the hard reset problem and improving performance-latency trade-offs. The method is shown to outperform standard supervised finetuning on coding tasks while reducing inference latency by up to 32%.

arXiv:2605.22967v1 Announce Type: new

Cached at: 05/25/26, 08:57 AM

# Learned Relay Representations for Forward-Thinking Discrete Diffusion Models
Source: [https://arxiv.org/html/2605.22967](https://arxiv.org/html/2605.22967)
Benjamin Rozonoyer (1), Jacopo Minniti (2), Dhruvesh Patel (1), Neil Band (3), Avishek Joey Bose (4, 5), Tim G. J. Rudner (2, 6), Andrew McCallum (1)

(1) University of Massachusetts Amherst, (2) University of Toronto, (3) Stanford University, (4) Imperial College London, (5) Mila, (6) Vijil

###### Abstract

When Masked Diffusion Models (MDMs) generate sequences through iterative refinement, the rich internal computation over masked positions is discarded. This forces every subsequent refinement step to recompute the valuable internal information stored as model representations. To avoid a hard reset between denoising rounds, we propose Learned Relay Representations (Relay), a method that allows MDMs to be “forward-thinking” when denoising by *explicitly learning how to propagate latent information for the benefit of future denoising steps*. Relay introduces a differentiable per-token channel that passes information between forward passes and is trained via truncated backpropagation through time (BPTT). We show that this framework can be scaled to state-of-the-art Diffusion Language Models (DLMs) and is seamlessly compatible with techniques like block diffusion and KV caching. We first provide a thorough justification of the design choices in Relay on a challenging Sudoku-based planning task. We then scale Relay to Fast-dLLM v2, a state-of-the-art DLM, outperforming standard supervised finetuning on coding tasks while reducing inference latency by up to 32%. Our empirical results demonstrate that state-of-the-art DLMs can be explicitly trained to *relay* latent information forward across decoding steps, advancing the performance-latency Pareto frontier. We provide code for all our experiments.

## 1 Introduction

Masked Diffusion Models (MDMs) generate discrete sequences via iterative denoising (Austin et al., [2021](https://arxiv.org/html/2605.22967#bib.bib1); Campbell et al., [2022](https://arxiv.org/html/2605.22967#bib.bib2); Sahoo et al., [2024](https://arxiv.org/html/2605.22967#bib.bib3); Shi et al., [2024](https://arxiv.org/html/2605.22967#bib.bib4)). Starting from a fully masked canvas, each forward pass unmasks a fraction of the remaining positions. The Transformer computes hidden states at every position, including those still masked. However, it discards them at the end of each step and begins the next pass from the partially unmasked sequence alone. We call this the *hard reset* problem: the only information that persists across steps is the set of discrete tokens committed so far, leaving MDMs with no way to accumulate intermediate continuous computation.

This matters because recurrent computation (unrolling a fixed-parameter model across many steps) is exactly the structural property that recent work has tied to improved performance on difficult reasoning tasks. It effectively expands the function class the model can approximate (Gatmiry et al., [2024](https://arxiv.org/html/2605.22967#bib.bib5); Saunshi et al., [2025](https://arxiv.org/html/2605.22967#bib.bib6); Li et al., [2024](https://arxiv.org/html/2605.22967#bib.bib7)). MDMs already perform many forward passes per generation; the hard reset is what prevents any of that compute from being reused.

This raises a natural question: How can the sequential unmasking structure of MDMs support recurrent computation that carries richer information across steps?

Our answer is Learned Relay Representations (Relay), a method that makes discrete diffusion models *forward-thinking*. At each denoising step, alongside any newly unmasked tokens, the model carries its last-layer hidden states forward as a learned *relay*, giving the next forward pass direct access to the prior step's continuous computation. Simply piping these states forward, however, does not by itself ensure that they encode anything useful for later steps. Relay therefore trains the relay end-to-end with truncated backpropagation through time (BPTT; Werbos, [1990](https://arxiv.org/html/2605.22967#bib.bib8)). This shapes the relay to be maximally informative for the next several denoising steps and enables a form of latent chain-of-thought across the unmasking trajectory.

Contributions. We introduce Relay, which equips MDMs with learned relay representations: continuous latent states passed forward across decoding steps and trained end-to-end via truncated BPTT. Relay is architecture-agnostic and leaves the inference-time decoding procedure of MDMs (unmasking schedule, sampling) unchanged; the only addition at inference is forwarding the relay alongside the committed tokens. It is also compatible with prevalent DLM acceleration techniques, including block diffusion (Arriola et al., [2025](https://arxiv.org/html/2605.22967#bib.bib9)) and KV caching (Wu et al., [2025a](https://arxiv.org/html/2605.22967#bib.bib10), [b](https://arxiv.org/html/2605.22967#bib.bib11)).

To summarize, our key contributions are as follows:

1. We propose Relay, a general method for incorporating recurrent computation into MDMs by training the model, via truncated BPTT, to pass a learned latent relay forward across decoding steps. Relay can train an MDM from scratch or adapt a pre-trained MDM through lightweight adaptation.
2. We validate Relay at LLM scale through full-parameter adaptation of Fast-dLLM v2 1.5B (Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)), outperforming standard supervised finetuning on coding tasks while reducing inference latency by up to 32%.
3. We perform extensive ablations that map out the design space of Relay and validate our choices.

## 2 Background: Masked Diffusion Models

We tackle the hard reset problem in masked diffusion models by training them to pass along a learned relay state. Before presenting our approach, Relay, we review the training and inference procedures for Masked Diffusion Models (MDMs) (Shi et al., [2024](https://arxiv.org/html/2605.22967#bib.bib4); Sahoo et al., [2024](https://arxiv.org/html/2605.22967#bib.bib3)) that Relay builds upon.

Notation. We denote the vocabulary, including the $\texttt{[M]}$ token, as $\mathcal{V}$. The space of sequences of length $L$ over the vocabulary is $\mathcal{V}^L$. Superscripts denote the position in the sequence; e.g., $x^i$ is the $i$-th token in the sequence $\bm{x} \in \mathcal{V}^L$. $\mathcal{M}(\bm{x}) \subseteq [L]$ denotes the set of masked positions in the sequence $\bm{x}$.

Training. The noising process samples a time $t \in [0,1]$ and masks each position of a clean sequence $\bm{x}_0 \in (\mathcal{V} \setminus \{\texttt{[M]}\})^L$ independently with probability $\alpha_t$, yielding the noised (partially masked) sequence $\bm{x}_t$. The coordinate-wise posterior distribution $\mathbb{P}(X_0^i = x_0^i \mid \bm{X}_t = \bm{x}_t)$ is denoted as $p(x_0^i \mid \bm{x}_t)$. As noted in Zheng et al. ([2024](https://arxiv.org/html/2605.22967#bib.bib12)), this posterior depends on $\bm{x}_t$ only through its mask pattern and revealed tokens, not on the time $t$ itself. The coordinate-wise posterior is parameterized by a neural network, denoted as $p_\theta^i(\cdot \mid \bm{x}_t) \in \Delta$ (a distribution over $\mathcal{V}$) for $i \in \mathcal{M}(\bm{x}_t)$. This network is trained by minimizing the weighted sum of cross-entropy losses over the masked positions (assuming a linear noise schedule):

$$\mathcal{L}(\theta) = \mathop{\mathbb{E}}_{\bm{x}_0, t, \bm{x}_t}\left[\frac{1}{t}\sum_{i:\, x_t^i = \texttt{[M]}} -\log p_\theta^i(x_0^i \mid \bm{x}_t)\right]. \tag{1}$$

The coordinate-wise parametric posterior is implemented using three components:

- an embedding $\text{Emb}_\theta : \mathcal{V} \to \mathbb{R}^d$;
- an unembedding $\text{UnEmb}_\theta : \mathbb{R}^d \to \mathbb{R}^{|\mathcal{V}|}$;
- a Transformer backbone $f_\theta : \mathbb{R}^{L \times d} \to \mathbb{R}^{L \times d}$, applied to the embedded sequence.

Together they produce the posterior distribution:

$$p_\theta^i(w \mid \bm{x}_t) = \frac{e^{\ell^i(w)}}{\sum_{w' \in \mathcal{V}} e^{\ell^i(w')}}, \quad \text{where} \quad \ell^i(w) = \text{UnEmb}_\theta\big(f_\theta(\text{Emb}_\theta(\bm{x}_t))\big)^i_w.$$
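To make Equation 1 concrete, here is a minimal pure-Python sketch of a single-sample loss estimate. It is not the authors' implementation, and it rests on several assumptions:

- The masking probability is $t$. This is our reading of the linear schedule, consistent with the $1/t$ weight.
- $\texttt{[M]}$ is represented by a hypothetical sentinel `MASK = -1` outside the token ids.
- Precomputed logits stand in for $\text{UnEmb}_\theta(f_\theta(\text{Emb}_\theta(\bm{x}_t)))$.

The helper names `noise`, `log_softmax`, and `mdm_loss` are illustrative.

```python
import math
import random

MASK = -1  # hypothetical sentinel for [M]; real token ids are 0..|V|-1


def noise(x0, t, rng):
    """Mask each position of x0 independently with probability t (linear schedule, assumed)."""
    return [MASK if rng.random() < t else tok for tok in x0]


def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def mdm_loss(logits, x0, xt, t):
    """Single-sample estimate of Eq. (1): (1/t) * sum over masked i of -log p_theta^i(x0^i | xt)."""
    total = 0.0
    for i, tok in enumerate(xt):
        if tok == MASK:  # only masked positions contribute
            total -= log_softmax(logits[i])[x0[i]]
    return total / t


rng = random.Random(0)
x0 = [2, 0, 3, 1]
print(noise(x0, 1.0, rng))            # [-1, -1, -1, -1]: t = 1 masks everything
xt = [MASK, 0, MASK, 1]               # two masked positions
logits = [[0.0] * 4 for _ in x0]      # uniform over a 4-token vocabulary
print(mdm_loss(logits, x0, xt, 0.5))  # 4 * log(4) ≈ 5.545
```

With uniform logits, each masked position contributes $\log 4$. Two masked positions at $t = 0.5$ therefore give $(1/0.5) \cdot 2\log 4 = 4\log 4$.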
Inference. Generation proceeds along a decreasing time grid $1 = t_0 > t_1 > \cdots > t_n = 0$. It iteratively unmasks positions, starting from the all-masked sequence $\bm{x}_{t_0} = (\texttt{[M]}, \ldots, \texttt{[M]})$ and ending at a fully unmasked sequence $\bm{x}_{t_n} \in (\mathcal{V} \setminus \{\texttt{[M]}\})^L$. At each step $k$, given the current partially masked sequence $\bm{x}_{t_k}$, the model computes logits $\bm{\ell}_k$ of the per-position posterior for every masked position $i \in \mathcal{M}(\bm{x}_{t_k})$ and token $w \in \mathcal{V}$. An unmasking policy $u(\cdot \mid \bm{\ell}_k, \bm{x}_{t_k})$, which may be stochastic, then selects a set of positions $\mathcal{U}_k \subseteq \mathcal{M}(\bm{x}_{t_k})$ to reveal, producing the next partially masked sequence $\bm{x}_{t_{k+1}}$. Common choices for $u$ include unmasking a fixed fraction of the remaining masks at each step (Nie et al., [2025](https://arxiv.org/html/2605.22967#bib.bib13)) and confidence-based rules (Ben-Hamu et al., [2025](https://arxiv.org/html/2605.22967#bib.bib14); Kim et al., [2025](https://arxiv.org/html/2605.22967#bib.bib15); Patel et al., [2025](https://arxiv.org/html/2605.22967#bib.bib16)).

The Hard Reset Problem. After each inference step, MDMs discard the entire computational state used to choose the newly revealed tokens, and the next step starts again from $\bm{x}_{t_{k+1}}$ alone. Standard MDM inference thus treats each partially masked sequence as a fresh prediction problem (a *hard reset*) rather than as the continuation of an ongoing computation. Because each forward pass performs only a fixed amount of computation, the hard reset prevents the model from effectively amortizing reasoning across steps. In the next section, we propose our solution to this problem: we learn a continuous latent state that is passed across the steps of MDM inference.
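The sketch below shows the inference loop and where the hard reset happens. It is illustrative pure Python under stated assumptions:

- `model` is any callable that returns per-position logits.
- Values at revealed positions are chosen greedily, a choice the text does not fix.
- `top_m_policy`, `toy_model`, and the `MASK` sentinel are hypothetical stand-ins.

```python
import math

MASK = -1  # sentinel for [M] (a simplification; real vocabularies include [M] as an id)


def top_prob(logits_i):
    """Largest softmax probability at one position: exp(max) / sum(exp)."""
    m = max(logits_i)
    return 1.0 / sum(math.exp(v - m) for v in logits_i)


def top_m_policy(m):
    """Hypothetical confidence-based policy: reveal the m most confident masked positions."""
    def policy(logits, masked):
        return sorted(masked, key=lambda i: top_prob(logits[i]), reverse=True)[:m]
    return policy


def generate(model, policy, length):
    """Standard MDM inference; returns the decoded tokens and the number of forward passes."""
    x = [MASK] * length
    nfe = 0
    while MASK in x:
        logits = model(x)                  # one forward pass over the whole canvas
        nfe += 1
        masked = [i for i, tok in enumerate(x) if tok == MASK]
        chosen = policy(logits, masked)    # U_k, a subset of M(x_{t_k})
        if not chosen:
            raise ValueError("the policy must reveal at least one position")
        for i in chosen:
            x[i] = max(range(len(logits[i])), key=logits[i].__getitem__)  # greedy value
        # Hard reset: `logits` (and any hidden states inside `model`) are dropped here;
        # the next pass starts from the partially unmasked tokens `x` alone.
    return x, nfe


def toy_model(x):
    """Ignores its input and confidently predicts token i % 3 at position i."""
    return [[5.0 if w == i % 3 else 0.0 for w in range(3)] for i in range(len(x))]


print(generate(toy_model, top_m_policy(2), 5))  # ([0, 1, 2, 0, 1], 3)
```

Because `sorted` is stable, ties in confidence are broken by position. The toy run therefore reveals positions {0, 1}, then {2, 3}, then {4}.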

## 3 Learned Relay Representations

To address the *hard reset* problem, we introduce a continuous, differentiable state that is carried across MDM inference steps.

### 3.1 Augmented State Trajectories

The training of MDMs proceeds by sampling a data point $\bm{x}_0 \sim p_{\mathrm{data}}$, a time $t \sim \mathcal{U}(0,1)$, and a partially masked sequence $\bm{x}_t$ under the noise schedule given $t$ and $\bm{x}_0$. During inference, we have a discretized time grid $1 = t_0 > \cdots > t_n = 0$ and the corresponding inference trajectory $\bm{x}_{t_0}, \ldots, \bm{x}_{t_n}$ obtained with some unmasking policy $u$, where $\bm{x}_{t_0} = \{\texttt{[M]}\}^L$. We wish to pass a continuous state forward across decoding steps, one that can carry intermediate computations from the previous step that have not yet been realized as decoded tokens. This behavior breaks down into two primitives: at inference step $k$ the model must produce a relay state $\bm{h}_{k+1}$, and at step $k+1$ it must learn to consume that relay state. [Figure 1](https://arxiv.org/html/2605.22967#S3.F1) shows a schematic of the augmented state trajectory produced by the model, where $\bm{s}_k = (\bm{x}_{t_k}, \bm{h}_k)$ is the augmented state at step $k$.

### 3.2 Training

Architecture. We parameterize the augmented dynamics with a backbone $f_\theta$, relay module $R_\theta$, token embedding $\text{Emb}_\theta$, and unembedding head $\text{UnEmb}_\theta$ (see [Figure 1](https://arxiv.org/html/2605.22967#S3.F1)). At step $k$, the model maps the current pair $(\bm{x}_{t_k}, \bm{h}_k)$ to the next relay state and per-position logits via

$$\bm{h}_{k+1} = f_\theta\big(\text{Emb}_\theta(\bm{x}_{t_k}) + R_\theta(\bm{h}_k)\big), \qquad \bm{\ell}_k = \text{UnEmb}_\theta(\bm{h}_{k+1}), \tag{2}$$

initialized with $\bm{h}_0 = \bm{0}$. The per-position posterior $p_\theta^i(\cdot \mid \bm{x}_{t_k}, \bm{h}_k)$ is read off from $\bm{\ell}_k$ by a softmax, exactly as in standard MDMs.
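Here is a minimal sketch of Equation 2 inside the same kind of decoding loop. It shows that the only change at inference is carrying $\bm{h}$ to the next pass. All components are hypothetical toys:

- A position-wise `tanh` stands in for the Transformer backbone $f_\theta$.
- `R` is a fixed scaling rather than a learned relay module.
- Tokens are committed greedily.

```python
import math

MASK = 3  # toy vocabulary {0, 1, 2}; id 3 plays the role of [M] (assumption)


def relay_step(x, h, emb, f, R, unemb):
    """Eq. (2): h_{k+1} = f(Emb(x_{t_k}) + R(h_k)) and l_k = UnEmb(h_{k+1}).

    x: token ids (length L); h: per-token relay vectors (L x d).
    """
    inputs = [[e + r for e, r in zip(emb(tok), R(h_i))] for tok, h_i in zip(x, h)]
    h_next = f(inputs)                     # backbone output doubles as the next relay
    logits = [unemb(v) for v in h_next]    # per-position logits l_k
    return h_next, logits


def generate_with_relay(length, d, emb, f, R, unemb, policy):
    """Same loop as standard MDM inference, except h is carried to the next pass."""
    x = [MASK] * length
    h = [[0.0] * d for _ in range(length)]  # h_0 = 0
    nfe = 0
    while MASK in x:
        h, logits = relay_step(x, h, emb, f, R, unemb)
        nfe += 1
        masked = [i for i, tok in enumerate(x) if tok == MASK]
        for i in policy(logits, masked):
            x[i] = max(range(len(logits[i])), key=logits[i].__getitem__)  # greedy value
    return x, nfe


# Toy, hypothetical components. A real backbone also mixes information across positions.
EMB = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0], MASK: [0.0, 0.0]}
emb = EMB.__getitem__


def f(rows):
    return [[math.tanh(v) for v in row] for row in rows]


def R(v):
    return [0.5 * a for a in v]


def unemb(v):
    return [v[0], v[1], v[0] + v[1]]


h1, _ = relay_step([MASK], [[1.0, 1.0]], emb, f, R, unemb)
print(h1[0][0] == math.tanh(0.5))  # True: the relay alone moves a masked position's state
print(generate_with_relay(3, 2, emb, f, R, unemb, lambda logits, masked: masked[:1]))
# ([0, 0, 0], 3)
```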

Since we only care about the terminal state $\bm{x}_{t_n}$, we continue to provide supervision with the same cross-entropy loss as in standard MDMs. We use truncated BPTT to train the model to produce relay states $\bm{h}_k$ that improve predictions up to $K$ steps ahead. Specifically, instead of sampling $\bm{x}_t$ as in standard MDMs, we start from an all-masked sequence $\bm{x}_{t_0} = \{\texttt{[M]}\}^L$ and roll out under [Equation 2](https://arxiv.org/html/2605.22967#S3.E2) together with an unmasking policy $u$ (see below). This produces the augmented trajectory $(\bm{x}_{t_0}, \bm{h}_0), \ldots, (\bm{x}_{t_n}, \bm{h}_n)$. The total training loss is the expected sum of per-step cross-entropies over the trajectory:

$$\mathcal{L}(\theta) = \mathop{\mathbb{E}}_{\bm{x}_0,\, \xi_{0:n-1}}\left[\sum_{k=0}^{n-1} \sum_{i \in \mathcal{M}(\bm{x}_{t_k})} -\log p_\theta^i\big(x_0^i \mid \bm{x}_{t_k}, \bm{h}_k\big)\right], \tag{3}$$

where $\xi_{0:n-1}$ denotes the exogenous randomness used by the unmasking policy along the rollout. Unlike an externally observed conditioning variable, $\bm{h}_k$ is an internal artifact of the rollout: part of the computational trajectory rather than of the generated object. At inference time each step carries forward the realized pair $(\bm{x}_{t_k}, \bm{h}_k)$. Only $\bm{x}_{t_k}$ is eventually decoded into text, while $\bm{h}_k$ serves as a differentiable memory channel for future predictions. The full procedure is summarized in [Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1); we derive the gradient estimator below.

Algorithm 1: Relay training

Input: model $f_\theta$, relay module $R_\theta$, unroll horizon $K$, unmasking policy $u$, training steps $N$, learning rate $\eta$

1. for $t \in \{1, \ldots, N\}$ do
2. if $t = 1$ or $\mathcal{M}(\bm{z}) = \emptyset$ then
3. $\bm{x}_0 \sim p_{\mathrm{data}}$, $\bm{z} \leftarrow \{\texttt{[M]}\}^L$, $\bm{h} \leftarrow \bm{0}$
4. end if
5. $L \leftarrow 0$
6. for $k \in \{0, \ldots, K-1\}$ do
7. $\bm{h} \leftarrow f_\theta\big(\text{Emb}_\theta(\bm{z}) + R_\theta(\bm{h})\big)$
8. $\bm{\ell} \leftarrow \text{UnEmb}_\theta(\bm{h})$
9. $L \leftarrow L + \mathcal{L}(\bm{\ell}, \bm{x}_0)$ ▷ masked positions only
10. $\mathcal{U} \sim u(\cdot \mid \bm{\ell}, \bm{z})$
11. $z^i \leftarrow x_0^i$ for all $i \in \mathcal{U}$
12. end for
13. $\theta \leftarrow \theta - \eta \nabla_\theta L$
14. end for
15. return $\theta$

[Figure 1: schematic image not reproduced; see caption below.]

Figure 1: Schematic of Relay over two consecutive inference steps. At each step $k$, the backbone $f_\theta$ consumes the sum of the embedded tokens $\text{Emb}_\theta(\bm{x}_{t_k})$ and the projected relay state $R_\theta(\bm{h}_k)$. It produces a hidden state $\bm{h}_{k+1}$ that is both unembedded into logits for the cross-entropy loss and forwarded through the relay module $R_\theta$ (orange path) into the next step. Tokens are progressively unmasked between steps (e.g., $\texttt{[M]} \to$ f at step $k$, and $\texttt{[M]} \to$ b, c at step $k+1$), while $\bm{h}$ provides a continuous, differentiable channel for information that has not yet been committed to a discrete token.
Constructing rollouts. To perform truncated BPTT, we need rollouts of the augmented state trajectory under an unmasking policy $u$. Given the current augmented state $(\bm{x}_{t_k}, \bm{h}_k)$, one rollout step proceeds as follows:

- Position selection: Sample which positions to unmask, $\mathcal{U} \sim u(\cdot \mid \bm{\ell}_k, \bm{x}_{t_k})$. The policy may use the model's own logits $\bm{\ell}_k$.
- Token forcing: For each $i \in \mathcal{U}$, commit the ground-truth token: $x_{t_{k+1}}^i = x_0^i$.

We teacher-force the token *values* (rather than sampling them from the model's posterior $p_\theta^i(\cdot \mid \bm{x}_{t_k}, \bm{h}_k)$) because sampled values would inject errors that the rollout has no mechanism to correct. The *position* sampler, by contrast, may use the model's own posterior without affecting the ideal minimizer. In the absence of the continuous channel, this leaves the standard MDM training objective ([Equation 1](https://arxiv.org/html/2605.22967#S2.E1)) unchanged (Kim et al., [2026](https://arxiv.org/html/2605.22967#bib.bib17)). For the augmented-state trajectory the same argument applies, but a formal proof requires additional assumptions and is more involved.
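The rollout bookkeeping of Algorithm 1 can be sketched in pure Python with the model and optimizer left abstract. The sketch covers three pieces: starting a fresh trajectory once everything is unmasked, running $K$ forward passes per window, and teacher-forcing the values. Its assumptions are:

- `relay_training_windows`, `step_fn`, `policy`, and `sample_x0` are hypothetical names.
- The relay is one scalar per token, for brevity.
- $\bm{h}$ is detached between windows, the usual truncated-BPTT practice. In this framework-free sketch that step is only a comment.

```python
import random

MASK = -1  # sentinel for [M] (simplification)


def relay_training_windows(sample_x0, step_fn, policy, K, N, rng):
    """Control flow of Algorithm 1; the model and the gradient step are left abstract.

    sample_x0(rng) -> clean sequence x0
    step_fn(z, h, x0) -> (h_next, logits, loss_k): Eq. (2) plus masked-position CE
    policy(logits, z, rng) -> positions to unmask (Algorithm 1, line 10)
    Returns the per-window loss terms; an optimizer would step on each window's sum.
    """
    x0 = z = h = None
    windows = []
    for t in range(1, N + 1):
        if t == 1 or MASK not in z:          # start a fresh trajectory
            x0 = sample_x0(rng)
            z = [MASK] * len(x0)
            h = [0.0] * len(x0)              # h <- 0 (one scalar per token here)
        # Truncation point: with autodiff, h would be detached here so gradients
        # stay inside the current K-step window (our assumption).
        losses = []                          # L <- 0
        for _ in range(K):
            h, logits, loss_k = step_fn(z, h, x0)
            losses.append(loss_k)
            for i in policy(logits, z, rng):  # line 10: choose positions
                z[i] = x0[i]                 # line 11: teacher-force the values
        windows.append(losses)
    return windows


def toy_step(z, h, x0):
    """Stand-in forward pass: the 'loss' is the number of masked positions."""
    return [v + 1.0 for v in h], None, sum(tok == MASK for tok in z)


def first_masked(logits, z, rng):
    return [z.index(MASK)] if MASK in z else []


windows = relay_training_windows(lambda rng: [rng.randrange(9) for _ in range(4)],
                                 toy_step, first_masked, K=2, N=4, rng=random.Random(0))
print(windows)  # [[4, 3], [2, 1], [4, 3], [2, 1]]
```

With a length-4 sequence, one reveal per pass, and $K = 2$, a trajectory spans two windows. The third outer iteration then starts a new trajectory because $\mathcal{M}(\bm{z}) = \emptyset$.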

Gradient estimation. We now derive the gradient estimator for one $K$-step window of the recurrence in [Equation 2](https://arxiv.org/html/2605.22967#S3.E2). Let $\xi_k$ denote the exogenous randomness used in the sampled unmasking step at step $k$:

$$\mathcal{U}_k \sim u(\cdot \mid \bm{\ell}_k, \bm{x}_{t_k}), \quad \text{and} \quad x_{t_{k+1}}^i \leftarrow x_0^i \quad \forall i \in \mathcal{U}_k. \tag{4}$$

Conditioning on the realized $\xi_{0:K-1}$, the per-window loss is

$$\mathcal{L}_K(\theta; \bm{x}_0, \xi_{0:K-1}) = \sum_{k=0}^{K-1} L_k(\bm{\ell}_k, \bm{x}_0), \tag{5}$$

where $L_k$ is the per-step cross-entropy at step $k$ (summed over the masked positions of $\bm{x}_{t_k}$), and $\bm{\ell}_k$ and $\bm{h}_{k+1}$ are computed from $(\bm{x}_{t_k}, \bm{h}_k)$ via [Equation 2](https://arxiv.org/html/2605.22967#S3.E2). The discrete update $\bm{x}_{t_k} \to \bm{x}_{t_{k+1}}$ is treated as fixed once the rollout is sampled. Equivalently, the estimator sets $\partial \bm{x}_{t_{k+1}} / \partial \bm{\ell}_k = 0$ and does not differentiate through the sampled unmasking decisions. The BPTT adjoints over the differentiable relay state are then defined by

$$\lambda_K = 0, \qquad \lambda_k = \big(\partial_{\bm{h}_k} \bm{\ell}_k\big)^\top \nabla_{\bm{\ell}_k} L_k(\bm{\ell}_k, \bm{x}_0) + \big(\partial_{\bm{h}_k} \bm{h}_{k+1}\big)^\top \lambda_{k+1}, \qquad k = K-1, \ldots, 0, \tag{6}$$

so that $\lambda_k$ is the gradient of the remaining window loss $\sum_{j \ge k} L_j$ with respect to $\bm{h}_k$.

Throughout, $\partial_{\bm{h}_k} \bm{\ell}_k$ and $\partial_\theta \bm{\ell}_k$ denote the *total* derivatives along the single-step chain $\bm{h}_k \to \bm{h}_{k+1} \to \bm{\ell}_k$. For example, $\partial_{\bm{h}_k} \bm{\ell}_k = \big(\partial_{\bm{h}_{k+1}} \text{UnEmb}_\theta\big)\big(\partial_{\bm{h}_k} \bm{h}_{k+1}\big)$, and analogously for $\theta$. By contrast, the companion factor $\big(\partial_\theta \bm{h}_{k+1}\big)^\top \lambda_{k+1}$ below uses the *direct* partial of step $k$'s transition only, with $\bm{h}_k$ held fixed. The boundary condition $\lambda_K = 0$ therefore reads as "no downstream losses past step $K-1$." The corresponding sampled gradient estimator is

$$\nabla_\theta \mathcal{L}_K = \sum_{k=0}^{K-1}\Big[\underbrace{\big(\partial_\theta \bm{\ell}_k\big)^\top \nabla_{\bm{\ell}_k} L_k(\bm{\ell}_k, \bm{x}_0)}_{\text{direct gradient from immediate cross-entropy}} + \underbrace{\big(\partial_\theta \bm{h}_{k+1}\big)^\top \lambda_{k+1}}_{\text{BPTT through relay state}}\Big]. \tag{7}$$

For a two-step truncation beginning at step $k$, we have $\lambda_{k+2} = 0$, so the only downstream adjoint is

$$\lambda_{k+1} = \big(\partial_{\bm{h}_{k+1}} \bm{\ell}_{k+1}\big)^\top \nabla_{\bm{\ell}_{k+1}} L_{k+1}(\bm{\ell}_{k+1}, \bm{x}_0). \tag{8}$$

The two-step gradient is therefore

$$\nabla_\theta\big(L_k + L_{k+1}\big) = \underbrace{\sum_{j=k}^{k+1}\big(\partial_\theta \bm{\ell}_j\big)^\top \nabla_{\bm{\ell}_j} L_j(\bm{\ell}_j, \bm{x}_0)}_{\text{direct gradient from immediate cross-entropy}} + \underbrace{\big(\partial_\theta \bm{h}_{k+1}\big)^\top \lambda_{k+1}}_{\text{BPTT through relay state}}.$$
Thus, each step receives the local cross-entropy gradient through its logits $\bm{\ell}_k$, and the additional recurrent gradient is backpropagated through the differentiable relay path $\bm{h}_k \to \bm{h}_{k+1}$.
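As a sanity check on Equations 6 and 7, the following pure-Python toy implements the adjoint recursion for a scalar recurrence and compares it with central finite differences. The model is an assumption chosen for brevity:

- $\text{Emb}_\theta(x) = a x$ and $R_\theta(h) = r h$;
- $f_\theta = \tanh$ and $\text{UnEmb}_\theta(h) = w h$;
- a squared error stands in for the cross-entropy $L_k$.

The token inputs are held fixed, matching the estimator's treatment of the discrete updates.

```python
import math


def forward(theta, xs, ys, h0=0.0):
    """Toy scalar recurrence: h_{k+1} = tanh(a*x_k + r*h_k), l_k = w*h_{k+1}."""
    a, r, w = theta
    hs, losses = [h0], []
    for x, y in zip(xs, ys):
        h_next = math.tanh(a * x + r * hs[-1])
        losses.append(0.5 * (w * h_next - y) ** 2)  # squared error stands in for L_k
        hs.append(h_next)
    return hs, losses


def bptt_grad(theta, xs, ys, h0=0.0):
    """Gradient of sum_k L_k via the adjoints of Eq. (6) and the estimator of Eq. (7)."""
    a, r, w = theta
    hs, _ = forward(theta, xs, ys, h0)
    K = len(xs)
    lam = [0.0] * (K + 1)                    # lam[K] = 0: no losses past step K-1
    for k in range(K - 1, -1, -1):
        g = 1.0 - hs[k + 1] ** 2             # tanh' at step k's pre-activation
        e = w * hs[k + 1] - ys[k]            # dL_k / dl_k
        lam[k] = (w * g * r) * e + (g * r) * lam[k + 1]
    grad = [0.0, 0.0, 0.0]
    for k in range(K):
        g = 1.0 - hs[k + 1] ** 2
        e = w * hs[k + 1] - ys[k]
        d_ell = (w * g * xs[k], w * g * hs[k], hs[k + 1])  # dl_k/d(a, r, w), h_k fixed
        d_h = (g * xs[k], g * hs[k], 0.0)                  # dh_{k+1}/d(a, r, w), h_k fixed
        for j in range(3):
            grad[j] += d_ell[j] * e + d_h[j] * lam[k + 1]
    return grad, lam


def finite_diff_grad(theta, xs, ys, eps=1e-6):
    """Central differences of the full window loss, for comparison."""
    grad = []
    for j in range(3):
        plus, minus = list(theta), list(theta)
        plus[j] += eps
        minus[j] -= eps
        grad.append((sum(forward(plus, xs, ys)[1]) - sum(forward(minus, xs, ys)[1])) / (2 * eps))
    return grad


theta, xs, ys = (0.7, -1.3, 0.9), [1.0, -0.5], [0.3, -0.8]  # a K = 2 window
g_bptt, lam = bptt_grad(theta, xs, ys)
g_fd = finite_diff_grad(theta, xs, ys)
print(max(abs(p - q) for p, q in zip(g_bptt, g_fd)))  # agreement up to finite-difference error
```

With $K = 2$, `lam[2]` is zero and `lam[1]` reduces to Equation 8. The `r` component of the gradient is the only one that would change if the relay path were stop-gradiented.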

## 4 Experiments

Through our experiments we seek to address the following research questions:

RQ1: Does training to be forward-thinking with BPTT improve performance and latency?

RQ2: Does weight-tying $\text{Emb}_\theta$ and $\text{UnEmb}_\theta$ affect Relay, given that the first layer of $f_\theta$ must learn to consume the relay $\bm{h}$? That relay comes from the last layer and is aligned with $\text{UnEmb}_\theta$.

RQ3: Can we efficiently adapt state-of-the-art DLMs to use relay representations and improve their performance-latency frontiers with negligible additional training FLOPs?

We first motivate the design choices for Relay with a thorough ablation on Sudoku. We then post-train Fast-dLLM v2 (Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)), a state-of-the-art DLM, demonstrating the effectiveness of Relay for adapting DLMs.

### 4.1 Sudoku

Dataset. The objective of a Sudoku puzzle is to fill in a 9×9 board (made of nine 3×3 sub-squares) with the digits 1-9 such that each row, column, and 3×3 sub-square contains all nine digits exactly once. A puzzle must provide at least 17 clues to have a unique solution (McGuire et al., [2012](https://arxiv.org/html/2605.22967#bib.bib18)).

Setup. We choose the Sudoku-Extreme dataset (Wang et al., [2025](https://arxiv.org/html/2605.22967#bib.bib19)) as a challenging benchmark that allows us to focus on modeling choices without the risk of overfitting. We release a derived version ([https://huggingface.co/datasets/brozonoyer/sapientinc-sudoku-extreme-timvink-sudoku-solver](https://huggingface.co/datasets/brozonoyer/sapientinc-sudoku-extreme-timvink-sudoku-solver)) that augments each puzzle with a step-by-step solver trajectory, a step count, and the set of deduction strategies invoked. These annotations come from running the Sudoku solver of Vink ([2024](https://arxiv.org/html/2605.22967#bib.bib20); [https://github.com/timvink/sudoku-solver](https://github.com/timvink/sudoku-solver)) over every example. The `strategies` field underpins the *deduction-only* evaluation slice in [Table 1](https://arxiv.org/html/2605.22967#S4.T1). For the experiments in [Figure 2](https://arxiv.org/html/2605.22967#S4.F2), we evaluate on the first 50k examples of the test split, which are representative in difficulty ([Section A.1.1](https://arxiv.org/html/2605.22967#A1.SS1.SSS1)). All methods use the same small Transformer backbone (∼7M parameters; full architecture in [Appendix A](https://arxiv.org/html/2605.22967#A1)) with rotary position embeddings. All are trained to convergence, in line with our experimental protocol of comparing methods by their test-time performance-versus-latency frontiers. We use the xLM package ([https://github.com/dhruvdcoder/xlm-core](https://github.com/dhruvdcoder/xlm-core)) for all small-scale experiments on Sudoku. Our predictor uses top-probabilities as confidence values $c_i$ and sorts positions by increasing $1 - c_i$. It then unmasks all positions whose cumulative uncertainty $\sum (1 - c_i)$ stays below a threshold $\tau$, falling back on the argmax if no such position exists.
For Relay's on-policy training rollout, we use a stochastic threshold $\tau \sim \mathcal{N}(\mu = 0.15, \sigma = 0.1)$ for robustness. The threshold is a hyperparameter of the sampling decision $\mathcal{U} \sim u(\cdot \mid \bm{\ell}, \bm{z})$ in [Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1) line 10.
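One possible reading of the threshold predictor, as a pure-Python sketch. We interpret the cumulative quantity as the running sum of $1 - c_i$, which makes a lower $\tau$ commit fewer cells, as stated for Figure 2. The function names and toy distributions are hypothetical, not taken from the xLM package.

```python
import random


def threshold_unmask(probs, masked, tau):
    """Cumulative-uncertainty threshold rule (our reading of the Sudoku predictor).

    probs[i] is the predicted distribution at position i, and c_i = max(probs[i]).
    Positions are visited in increasing 1 - c_i and unmasked while the running
    sum of 1 - c_i stays below tau; if none qualifies, the most confident one is used.
    """
    if not masked:
        return []
    conf = {i: max(probs[i]) for i in masked}
    order = sorted(masked, key=lambda i: 1.0 - conf[i])  # most confident first
    chosen, cumulative = [], 0.0
    for i in order:
        cumulative += 1.0 - conf[i]
        if cumulative >= tau:
            break
        chosen.append(i)
    return chosen or [order[0]]  # argmax fallback


def sample_train_tau(rng, mu=0.15, sigma=0.1):
    """Stochastic training threshold tau ~ N(mu, sigma); a negative draw simply
    triggers the argmax fallback above."""
    return rng.gauss(mu, sigma)


probs = {0: [0.98, 0.02], 1: [0.90, 0.10], 2: [0.50, 0.50]}
for tau in (0.01, 0.05, 0.15, 0.70):
    print(tau, threshold_unmask(probs, [0, 1, 2], tau))
# 0.01 [0]
# 0.05 [0]
# 0.15 [0, 1]
# 0.7 [0, 1, 2]
print(threshold_unmask(probs, [0, 1, 2], sample_train_tau(random.Random(0))))
```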

Baselines and ablations. We compare four training objectives that progressively turn on the components of [Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1), each instantiated with both *tied* and *untied* embeddings (whether $\text{Emb}_\theta$ and $\text{UnEmb}_\theta$ share weights). MLM (Sahoo et al., [2024](https://arxiv.org/html/2605.22967#bib.bib3); Shi et al., [2024](https://arxiv.org/html/2605.22967#bib.bib4)) is standard uniform masked diffusion: a single forward pass per training step ($K=1$), no relay ($R_\theta \equiv 0$, so $\bm{h} \leftarrow f_\theta(\text{Emb}_\theta(\bm{z}))$), and no inner rollout. Instead, the masked input $\bm{z}$ is drawn fresh each step by sampling $t \sim \mathcal{U}(0,1)$ and masking each token of $\bm{x}_0$ independently with probability $t$. The remaining three objectives all share Relay's on-policy *position* sampler $u(\cdot \mid \bm{\ell}, \bm{z})$ ([Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1), line 10) and teacher-force the committed positions to the values in $\bm{x}_0$ between passes (line 11), differing only in whether and how the relay channel is used (a related rollout training procedure is studied by Kim et al., [2026](https://arxiv.org/html/2605.22967#bib.bib17)). Rollout unrolls $K=2$ inner steps but keeps $R_\theta \equiv 0$, so each step recomputes $\bm{h}$ from $\text{Emb}_\theta(\bm{z})$ alone; this isolates the contribution of *which* positions get committed between forward passes. Relay(sg) additionally enables the relay path $R_\theta(\bm{h})$ inside the inner loop but stop-gradients $\bm{h}$ before feeding it back, so the backbone receives no temporal credit across the $K$ steps. Relay is the full method: $K=2$ BPTT through the relay ([Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1)). At inference we sweep deterministic thresholds $\tau \in \{0.05, 0.10, 0.15, 0.20, 0.25\}$ and trace each method's accuracy-NFE frontier; a lower $\tau$ commits fewer cells per forward pass and so spends more NFEs.
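The four objectives differ only in a few switches, and the MLM baseline's corruption is short enough to write out. The sketch below records both, assuming a hypothetical mask id of 10 and NumPy for sampling; the names `Objective`, `mlm_noise`, and `sample_train_threshold` are illustrative, not the authors' code. It also draws Relay's stochastic training threshold $\tau \sim \mathcal{N}(0.15, 0.1)$. The text does not say whether $\tau$ is drawn per rollout, per example, or per step, nor how negative draws are handled, so this sketch draws one scalar per call and does not clip.

```python
from dataclasses import dataclass

import numpy as np

MASK_ID = 10  # hypothetical id of the mask token


@dataclass(frozen=True)
class Objective:
    name: str
    inner_steps: int          # K: forward passes unrolled per training step
    relay: bool               # is R_theta(h) fed into the next forward pass?
    bptt_through_relay: bool  # False: h is stop-gradiented before it is fed back


OBJECTIVES = [
    Objective("MLM", inner_steps=1, relay=False, bptt_through_relay=False),
    Objective("Rollout", inner_steps=2, relay=False, bptt_through_relay=False),
    Objective("Relay(sg)", inner_steps=2, relay=True, bptt_through_relay=False),
    Objective("Relay", inner_steps=2, relay=True, bptt_through_relay=True),
]


def mlm_noise(x0: np.ndarray, rng: np.random.Generator):
    """MLM baseline corruption: t ~ U(0, 1), then mask each token independently w.p. t."""
    t = rng.uniform(0.0, 1.0)
    drop = rng.random(x0.shape) < t
    z = np.where(drop, MASK_ID, x0)
    return z, t


def sample_train_threshold(rng: np.random.Generator) -> float:
    """Stochastic training threshold tau ~ N(mu=0.15, sigma=0.1); no clipping (assumption)."""
    return float(rng.normal(0.15, 0.1))


EVAL_THRESHOLDS = (0.05, 0.10, 0.15, 0.20, 0.25)  # deterministic sweep at inference
```

Keeping the objectives as data rather than separate code paths mirrors how the ablation is framed: each row turns on one more component than the row above it.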

![Figure 2](https://arxiv.org/html/2605.22967v1/x1.png)![Figure 2](https://arxiv.org/html/2605.22967v1/x2.png)

Legend: ◆ MLM, ■ Rollout, ● Relay(sg), ▲ Relay; tied and untied embeddings; marker size ∝ τ.
Figure 2: Accuracy-NFE frontier on Sudoku-Extreme validation. Each curve traces a single training method as we sweep the inference confidence threshold $\tau \in \{0.05, 0.10, 0.15, 0.20, 0.25\}$. A lower $\tau$ commits fewer cells per forward pass and so spends more NFEs (rightward), and vice versa. Shaded ribbons denote ±1 sample standard deviation across three training seeds.

Results and analysis. [Figure 2](https://arxiv.org/html/2605.22967#S4.F2) plots validation metrics at the latest checkpoint for each seed. Replacing uniform masking (MLM) with the on-policy confidence-thresholded sampler under teacher forcing of the unmasked values (Rollout) yields the first improvement. Turning the relay channel on (Relay(sg)) contributes the next large jump, highlighting the importance of a soft state carried between forward passes. Finally, replacing the stop-gradient with $K=2$ BPTT through the relay (Relay, [Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1)) yields a further separation and the best accuracy-NFE frontier across thresholds.

We trace this last separation of Relay over Relay(sg) to the fact that, at the same threshold $\tau$, Relay commits more cells per forward pass while keeping the partial board legal, where a board is *legal* when no row, column, or 3×3 box yet contains a repeated digit. Legality is a necessary condition for correctness and is well-defined at every intermediate denoising step, not only at the end. Since the studied architectures cannot perform recursive search, we restrict this qualitative analysis to a *deduction-only* cohort of 2,000 test puzzles for which the solver uses only human-like deduction strategies (Advanced or Master heuristics; cohort construction is detailed in [Section A.1.1](https://arxiv.org/html/2605.22967#A1.SS1.SSS1)).
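One way to check the legality of a partial board, and to count violations, is sketched below. Blank or still-masked cells are encoded as 0, and counting one violation per row, column, or box that contains a repeated digit is our assumption; the paper does not spell out its exact counting rule.

```python
from itertools import product


def units():
    """Yield the 27 Sudoku units (9 rows, 9 columns, 9 boxes) as lists of flat indices."""
    for r in range(9):
        yield [r * 9 + c for c in range(9)]
    for c in range(9):
        yield [r * 9 + c for r in range(9)]
    for br, bc in product(range(0, 9, 3), repeat=2):
        yield [(br + dr) * 9 + (bc + dc) for dr in range(3) for dc in range(3)]


def count_violations(board):
    """Number of units that already contain a repeated digit (0 = blank/masked, ignored)."""
    violations = 0
    for unit in units():
        digits = [board[i] for i in unit if board[i] != 0]
        if len(digits) != len(set(digits)):
            violations += 1
    return violations


def is_legal(board):
    """A partial board is legal when no unit contains a repeated digit."""
    return count_violations(board) == 0
```

Because blanks are skipped, the check applies unchanged to every intermediate denoising state, which is what makes legality usable as a per-step diagnostic.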

At the matched threshold $\tau=0.15$, Relay produces a fully legal final board 74.8% of the time versus 70.7% for Relay(sg) (+4.1 pp), and incurs 15% fewer row/column/box violations across the rollout (0.90 vs. 1.06 on average per puzzle); these legality gains are uniform across the Advanced (+4.0 pp) and Master (+4.1 pp) strata. In other words, BPTT teaches the relay to keep the partial board self-consistent under more aggressive unmasking: at the same confidence threshold $\tau$, Relay commits more cells per forward pass while still honoring the row/column/box constraints, so the rollout reaches the same accuracy in fewer total forward passes. This produces the strict outward shift of the accuracy-NFE frontier traced over $\tau$ in [Figure 2](https://arxiv.org/html/2605.22967#S4.F2). [Table 1](https://arxiv.org/html/2605.22967#S4.T1) reports the corresponding exact match and mean NFE at $\tau=0.15$ on both the unfiltered test split and the deduction-only cohort: Relay attains the highest exact match and the lowest mean NFE in every (slice, tying) cell, with a +4 to +6 pp gain over Relay(sg) at uniformly lower NFE. Tying versus untying $\text{Emb}_\theta$ and $\text{UnEmb}_\theta$ has only a marginal effect on any objective (≤3 pp exact match across all rows), consistent with the small Sudoku vocabulary leaving the residual stream ample capacity to carry both predictive and relay-bearing information.

Table 1: Sudoku exact match and mean NFE at $\tau=0.15$. *Unfiltered* reports performance on puzzles taken from the test split in dataset order; *deduction-only* restricts to puzzles whose solver trace requires Advanced/Master heuristics (no recursive backtracking). Accuracies are % exact match, with sample s.d. across 3 training seeds. See [Section A.1.1](https://arxiv.org/html/2605.22967#A1.SS1.SSS1) for more details.
### 4.2 Pretrained Model Adaptation: Fast-dLLM v2

Next, we investigate whether state-of-the-art DLMs can be efficiently adapted into Relay diffusion models with a limited amount of finetuning, and whether this adaptation can improve their accuracy-latency frontiers.

Base model. As our base model, we choose Fast-dLLM v2 (1.5B parameters) (Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)), a state-of-the-art DLM adapted from Qwen2.5 (Yang et al., [2024](https://arxiv.org/html/2605.22967#bib.bib21)) by finetuning on the LLaMA-Nemotron dataset (Bercovich et al., [2025](https://arxiv.org/html/2605.22967#bib.bib22)).

Training. For Relay adaptation we apply supervised fine-tuning to all parameters for 200 optimizer steps at an effective batch size of 32 on a 60,000-example mixture of filtered OpenCodeInstruct and OpenMathInstruct-2 examples with a 40/60 code/math proportion (dataset and hardware details in [Section A.2.3](https://arxiv.org/html/2605.22967#A1.SS2.SSS3)). To make [Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1) compatible with state-of-the-art DLMs that combine block-autoregressive decoding with KV caching, we make two adaptations to the on-policy rollout. First, we run the $K=2$ relay rollout *only inside the active block* of Fast-dLLM v2's BD3-LM-style doubled (block-causal $\oplus$ block-bidirectional) attention (Arriola et al., [2025](https://arxiv.org/html/2605.22967#bib.bib9); Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)), leaving previously decoded blocks frozen so that their inter-block KV cache is reused unchanged across both passes. Second, within the active block we update the relay state $\bm{h}$ *only at positions that are still masked*: clean (already committed) sub-block tokens still contribute to attention, but their relay entries are not overwritten, which keeps the within-block sub-block KV cache entries valid as the block fills in.
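The second adaptation amounts to a masked write into the relay buffer. Below is a minimal NumPy sketch, assuming the relay state is stored as a `(T, d)` array over the whole sequence and the active block is a contiguous span; `update_relay` is a hypothetical helper, not the authors' code.

```python
import numpy as np


def update_relay(h_prev, h_new, still_masked, block_start, block_len):
    """Overwrite relay entries only at still-masked positions of the active block.

    h_prev, h_new: (T, d) relay states before/after a forward pass.
    still_masked:  (T,) bool, True where the token is still the mask token.
    """
    h_out = h_prev.copy()  # earlier blocks and committed tokens keep their entries
    in_block = np.zeros(h_prev.shape[0], dtype=bool)
    in_block[block_start:block_start + block_len] = True
    write = in_block & still_masked
    h_out[write] = h_new[write]
    return h_out
```

Leaving committed positions untouched is what lets any cached keys and values derived from them stay consistent while the rest of the block keeps denoising.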

Table 2: Pretrained adaptation on Fast-dLLM v2 (1.5B), evaluated at threshold 0.85. Average NFE is computed as the mean per-example count of active denoising forward calls during batched sample generation, excluding prompt prefill and final cache-update next-token forwards. Bold values are selected among adapted rows only, excluding the off-the-shelf baseline.

Evaluation. Inference follows Fast-dLLM v2's confidence-based parallel decoding (Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)): within each block, the backbone applies the token-shift head so that masked positions are read from the preceding token's logit row; samples are drawn with top-$p$ filtering ($p=0.95$, temperature 0); and a position unmasks when the probability of its sampled token exceeds a confidence threshold $\tau$, while the most confident (argmax) masked position in each active sub-block is always unmasked so that every forward pass makes progress. We use block length 32, sub-block length 8, and $\tau=0.85$ for all HumanEval/MBPP numbers (including NFE in [Table 2](https://arxiv.org/html/2605.22967#S4.T2)). The *Plus* columns report HumanEval+ and MBPP+ from EvalPlus (Liu et al., [2023](https://arxiv.org/html/2605.22967#bib.bib23)), the expanded unit-test suites released with the EvalPlus framework ([https://github.com/evalplus/evalplus](https://github.com/evalplus/evalplus)), in the same *Base*/*Plus* layout used for code results in Wu et al. ([2025b](https://arxiv.org/html/2605.22967#bib.bib11)).
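A NumPy sketch of one decoding step of the rule described above, for a single active sub-block and under stated simplifications: at temperature 0 the sampled token is the argmax, so top-$p$ filtering cannot change it and is omitted; the sub-block is assumed to start after at least one earlier position so that the token-shifted row $i-1$ exists; and `decode_step` is a hypothetical name, not Fast-dLLM v2's API.

```python
import numpy as np


def decode_step(probs, tokens, masked, sub_block, tau=0.85):
    """One greedy (temperature 0) confidence-thresholded step over an active sub-block.

    probs:     (T, V) next-token probabilities; row i-1 scores position i (token shift).
    tokens:    (T,) current token ids.
    masked:    (T,) bool, True where the position is still masked.
    sub_block: (start, end) half-open span of the active sub-block.
    """
    start, end = sub_block
    if start < 1:
        raise ValueError("token shift needs a preceding position")
    tokens, masked = tokens.copy(), masked.copy()
    cand = [i for i in range(start, end) if masked[i]]
    if not cand:
        return tokens, masked
    shifted = probs[np.array(cand) - 1]  # token shift: read the preceding logit row
    pred = shifted.argmax(axis=-1)       # temperature 0: the sampled token is the argmax
    conf = shifted.max(axis=-1)          # probability of that token
    commit = conf > tau
    commit[conf.argmax()] = True         # always unmask the most confident masked position
    for i, tok, c in zip(cand, pred, commit):
        if c:
            tokens[i] = tok
            masked[i] = False
    return tokens, masked
```

The forced commit of the most confident position guarantees termination within at most one forward pass per masked position, even when nothing clears $\tau$.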

As in Sudoku, Relay pushes the accuracy-NFE frontier here: it attains the best raw NFE among adapted methods on both HumanEval and MBPP, while also reaching the best accuracies. Notably, on HumanEval, Relay even surpasses the vanilla SFT accuracy with 32% fewer NFEs (88.3 vs. 130.7), demonstrating that Relay improves both accuracy and the number of denoising steps required to reach it.
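The 32% figure follows directly from the two HumanEval NFE values:

$$
\frac{130.7 - 88.3}{130.7} = \frac{42.4}{130.7} \approx 0.324,
$$

i.e., Relay uses roughly 32% fewer denoising forward passes than vanilla SFT on HumanEval.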

#### 4.2.1 Training memory overhead

A natural concern is that BPTT through $K=2$ forward passes inflates training memory. [Figure 3](https://arxiv.org/html/2605.22967#S4.F3) profiles one micro-step on an A100 80GB. Each regime is shown with two curves: the solid trace samples live GPU memory at every transformer-layer hook, and the dashed trace is the running maximum of live memory, a high-water mark whose final value is the peak the run actually demanded. A short-lived allocation can therefore lift the dashed trace even if it is freed before the next solid-line sample. The largest such transient, and the binding peak of the whole micro-step in both regimes, is the cross-entropy backward through the vocabulary-projection head (`lm_head`), which materializes a $B \times T \times V$ fp32 grad-of-logits buffer at the start of `bwd`.
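To get a feel for the size of that buffer, the sketch below converts $B \times T \times V$ fp32 elements into GiB. The batch size and sequence length in the example are hypothetical (they are not stated here), and the vocabulary size of 151,936 is an assumption based on the Qwen2.5 configuration; Fast-dLLM v2's exact vocabulary may differ. In PyTorch, a live trace and a high-water mark like the two curves in Figure 3 can be read from `torch.cuda.memory_allocated()` and `torch.cuda.max_memory_allocated()`, although which tooling the authors used is not stated.

```python
def logits_grad_gib(batch, seq_len, vocab, bytes_per_elem=4):
    """Size in GiB of a dense batch x seq_len x vocab gradient-of-logits buffer (fp32 by default)."""
    return batch * seq_len * vocab * bytes_per_elem / 2**30


# Hypothetical example: one 2048-token sequence with a Qwen2.5-sized vocabulary.
example_gib = logits_grad_gib(1, 2048, 151_936)  # about 1.16 GiB
```

The buffer grows linearly in every factor, which is why a large vocabulary can make this single transient, rather than the extra relay pass, set the peak.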

Relay's second forward raises the live trace by ≈5 GiB through `fwd2`: the saved activations of forward 1 and the relay state $\bm{h}$ coexist with forward 2 so that credit can be routed through both passes ([Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1)). Most of that elevation is autograd intermediates rather than saved-for-backward state, and PyTorch releases it in a single step before the `lm_head` spike fires: live memory drops by ≈7 GiB for Relay versus ≈2.7 GiB for vanilla SFT, leaving the two regimes within ≈0.5 GiB of each other just before the spike. Adding the spike yields nearly identical peaks, 20.1 GiB for Relay versus 21.2 GiB for vanilla SFT; in fact, Relay's larger pre-spike drop edges its peak slightly below vanilla's. BPTT through $K=2$ passes therefore does not double peak memory in this setup (gradient checkpointing, ZeRO-3, non-fused cross-entropy), and we expect the same whenever the `lm_head` backward dominates the activation footprint. Per-phase numbers and the profiling protocol are deferred to [Appendix B](https://arxiv.org/html/2605.22967#A2).

![Figure 3](https://arxiv.org/html/2605.22967v1/x3.png)

Figure 3: GPU memory during one training micro-step of Fast-dLLM v2 on an A100 80GB. Solid lines show the live GPU memory at every decoder-layer forward/backward hook. Dashed lines show the running maximum of live memory within the same micro-step (high-water mark). Phase labels (`fwd`, `fwd2`, `bwd`) mark each phase's plateau. Relay carries higher live memory through `fwd2`, but its peak (≈20.1 GiB) lands within ≈1 GiB of vanilla SFT's (≈21.2 GiB); see main text and [Appendix B](https://arxiv.org/html/2605.22967#A2) for the mechanism.

## 5 Related Work

Discrete diffusion models (Austin et al., [2021](https://arxiv.org/html/2605.22967#bib.bib1)), which apply the iterative denoising principles of continuous diffusion (Ho et al., [2020](https://arxiv.org/html/2605.22967#bib.bib24); Song et al., [2020](https://arxiv.org/html/2605.22967#bib.bib25)) to categorical sequences, have emerged as a strong framework for language modeling. In particular, Masked Diffusion Models (MDMs) (Sahoo et al., [2024](https://arxiv.org/html/2605.22967#bib.bib3); Shi et al., [2024](https://arxiv.org/html/2605.22967#bib.bib4)), which generate sequences by iterative unmasking, have been shown to scale well to larger model sizes (Nie et al., [2025](https://arxiv.org/html/2605.22967#bib.bib13); Ye et al., [2025](https://arxiv.org/html/2605.22967#bib.bib26); Wu et al., [2025a](https://arxiv.org/html/2605.22967#bib.bib10), [b](https://arxiv.org/html/2605.22967#bib.bib11)). The same diffusion-style training that makes MDMs simple also limits what can be communicated between denoising steps: rich internal representations are collapsed into sampled tokens before the next step begins. Some recent works call this collapse of internal information a “sampling wall” or an “information island” (Jo et al., [2025](https://arxiv.org/html/2605.22967#bib.bib27); Xia et al., [2026](https://arxiv.org/html/2605.22967#bib.bib28)).

To address this, several recent approaches use a continuous relaxation or augmented state trajectories. Soft-Masked Diffusion (Hersche et al., [2025](https://arxiv.org/html/2605.22967#bib.bib30)) passes output distributions or top-$k$ predictions from the previous step back into the model's input for the next step. CADD (Zheng et al., [2025](https://arxiv.org/html/2605.22967#bib.bib29)) augments each discrete variable with a continuous variable that is trained using a continuous diffusion process in the embedding space. VADD (Xie et al., [2025](https://arxiv.org/html/2605.22967#bib.bib31)), on the other hand, trains a VAE atop discrete diffusion. All these approaches rely on a continuous diffusion process to carry more information across steps, even though we ultimately care only about the discrete variables. In contrast, our approach provides supervision through the discrete variables only.

MetaState (Xia et al., [2026](https://arxiv.org/html/2605.22967#bib.bib28)) and Loopholing (Jo et al., [2025](https://arxiv.org/html/2605.22967#bib.bib27)) introduce a continuous pathway that carries hidden state across steps and train it without relying on continuous diffusion, which makes them quite similar to our approach. MetaState adds a fixed-size working memory to frozen dLLMs and trains it over multi-step denoising rollouts. Loopholing, on the other hand, simply injects the hidden state from the previous step into the input of the current step, like our Relay(sg) setting in the ablations, and trains the entire model. This allows the model to *learn to use* the hidden state for future predictions. Relay goes one step further by training the hidden state end-to-end via BPTT, which also allows the model to *learn to save* information in the hidden state for future predictions.
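The difference between *learning to use* and *learning to save* shows up even in a two-step scalar toy (not the paper's model): a backbone weight `w` produces the hidden state `h1`, a hypothetical relay weight `r` feeds `r * h1` into the second pass, and both passes are supervised. With a stop-gradient on `h1`, as in Relay(sg), `r` is still trained, but `w` loses exactly the term that credits it for what `h1` contributes to step 2; BPTT keeps that term.

```python
def loss(w, r, e1, e2, y1, y2):
    """Two-step toy rollout: step 2 reads the relay r * h1 produced by step 1."""
    h1 = w * e1                  # step 1: no incoming relay
    h2 = w * (e2 + r * h1)       # step 2: relay added to the input
    return 0.5 * (h1 - y1) ** 2 + 0.5 * (h2 - y2) ** 2


def relay_grads(w, r, e1, e2, y1, y2, stop_grad):
    """Analytic (dL/dw, dL/dr) with full BPTT (stop_grad=False) or a stop-gradient on h1."""
    h1 = w * e1
    h2 = w * (e2 + r * h1)
    g1, g2 = h1 - y1, h2 - y2
    dh2_dw = e2 + r * h1         # direct dependence of step 2 on w
    if not stop_grad:
        dh2_dw += w * r * e1     # temporal credit: w also shaped h1, which step 2 reads
    dL_dw = g1 * e1 + g2 * dh2_dw
    dL_dr = g2 * w * h1          # the relay weight is trained either way ("learn to use")
    return dL_dw, dL_dr
```

A central finite difference of `loss` with respect to `w` matches the `stop_grad=False` gradient, while the stop-gradient variant differs by `g2 * w * r * e1`, the "learn to save" signal.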

## 6 Discussion

Summary of results. Masked diffusion models suffer from a *hard reset* between denoising steps: the Transformer computes rich hidden states at every position, including those still masked, but discards them at the end of each forward pass, so the only information that persists is the discrete tokens just committed. Relay addresses this by carrying the last-layer hidden states forward as a learned relay and training it end-to-end via truncated BPTT, so the model is explicitly rewarded for writing hidden states that will be useful to future denoising steps. Empirically, each of the three components that constitute Relay (a rollout-based training procedure, passing the hidden state forward across denoising steps, and training the hidden state end-to-end via BPTT) successively pushes the performance-latency frontier, and they combine constructively. On Sudoku-Extreme, the full method attained the best accuracy-per-NFE point on the Pareto frontier ([Figure 2](https://arxiv.org/html/2605.22967#S4.F2)); on Fast-dLLM v2 it outperformed standard supervised fine-tuning on coding tasks while reducing inference latency by up to 32%.

Limitations. Relay introduces two computational trade-offs. First, the relay mechanism adds a small per-step overhead for reading and writing the continuous relay state, though the reduced number of forward passes needed to reach a given accuracy can still yield a net inference-latency improvement. Second, two-step BPTT during training increases live activation memory and compute per training step. In our Fast-dLLM v2 profile, however, the extra live memory stays below the peak set by the final vocabulary-projection backward pass ([Section 4.2.1](https://arxiv.org/html/2605.22967#S4.SS2.SSS1), [Appendix B](https://arxiv.org/html/2605.22967#A2)), leaving observed peak GPU memory nearly unchanged. Overall, while Relay requires more training time than vanilla MLM training, this gap can be amortized by improvements to the inference-time accuracy-latency frontier and narrowed with more careful engineering.

Outlook and future work. Relay is a meaningful step towards a non-greedy, forward-thinking approach to iterative non-autoregressive generation, and there are several natural follow-up directions. The relay state gives a diffusion model a continuous substrate on which to carry intermediate computation across denoising steps; understanding what this state encodes, and whether it can be probed or steered, is a promising direction for interpreting and improving latent reasoning in MDMs. Because the relay mechanism is largely architecture- and modality-agnostic, applying it beyond text (for example, to image or molecular discrete diffusion) is also a natural next step.

## 7 Conclusion

We introduced Relay to address the hard reset problem in MDMs by passing a continuous, differentiable latent state across inference steps. By training a relay channel via truncated BPTT, we demonstrated that discrete diffusion models can explicitly optimize intermediate representations for future unmasking decisions, advancing the performance-latency Pareto frontier.

## Acknowledgments

DP, BR, and AM thank Michael Boratko for helpful initial discussions. DP and BR acknowledge support from IBM under IBM Research Collaboration Agreement No. W1668553 and from the National Science Foundation under grant IIS-2106391. NB acknowledges support from an NSF Graduate Research Fellowship, Quad Fellowship, and Mercor Graduate Fellowship. TGJR acknowledges support provided, in part, by the Province of Ontario, the Government of Canada through CIFAR, the Vector Institute for Artificial Intelligence, and by the Digital Research Alliance of Canada (alliancecan.ca).

## References

- Austin et al. [2021]. Jacob Austin, Daniel D. Johnson, Jonathan Ho, Daniel Tarlow, and Rianne van den Berg. Structured denoising diffusion models in discrete state-spaces. In *Neural Information Processing Systems*, pages 17981–17993, 2021.
- Campbell et al. [2022]. Andrew Campbell, Joe Benton, Valentin De Bortoli, Tom Rainforth, George Deligiannidis, and A. Doucet. A continuous time framework for discrete denoising models. In *Neural Information Processing Systems*, pages 28266–28279, 2022. doi:10.48550/arXiv.2205.14987. URL [https://openreview.net/forum?id=DmT862YAieY](https://openreview.net/forum?id=DmT862YAieY).
- Sahoo et al. [2024]. Subham Sekhar Sahoo, Marianne Arriola, Yair Schiff, Aaron Gokaslan, Edgar Marroquin, Justin T. Chiu, Alexander Rush, and Volodymyr Kuleshov. Simple and effective masked diffusion language models, 2024. URL [http://arxiv.org/abs/2406.07524](http://arxiv.org/abs/2406.07524).
- Shi et al. [2024]. Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, and Michalis K. Titsias. Simplified and generalized masked diffusion for discrete data, 2024. URL [http://arxiv.org/abs/2406.04329](http://arxiv.org/abs/2406.04329).
- Gatmiry et al. [2024]. Khashayar Gatmiry, Nikunj Saunshi, Sashank J. Reddi, Stefanie Jegelka, and Sanjiv Kumar. Can looped transformers learn to implement multi-step gradient descent for in-context learning? In Ruslan Salakhutdinov, Zico Kolter, Katherine Heller, Adrian Weller, Nuria Oliver, Jonathan Scarlett, and Felix Berkenkamp, editors, *International Conference on Machine Learning*, volume 235 of *Proceedings of Machine Learning Research*, pages 15130–15152. PMLR, 21–27 Jul 2024. doi:10.48550/arXiv.2410.08292. URL [https://proceedings.mlr.press/v235/gatmiry24b.html](https://proceedings.mlr.press/v235/gatmiry24b.html).
- Saunshi et al. [2025]. Nikunj Saunshi, Nishanth Dikkala, Zhiyuan Li, Sanjiv Kumar, and Sashank J. Reddi. Reasoning with latent thoughts: On the power of looped transformers. In *International Conference on Learning Representations*, 2025. doi:10.48550/arXiv.2502.17416.
- Li et al. [2024]. Zhiyuan Li, Hong Liu, Denny Zhou, and Tengyu Ma. Chain of thought empowers transformers to solve inherently serial problems. In *International Conference on Learning Representations*, 2024. doi:10.48550/arXiv.2402.12875.
- Werbos [1990]. Paul J. Werbos. Backpropagation through time: what it does and how to do it. *Proceedings of the IEEE*, 78(10):1550–1560, 1990. doi:10.1109/5.58337.
- Arriola et al. [2025]. Marianne Arriola, Aaron Gokaslan, Justin T Chiu, Zhihan Yang, Zhi-Hong Qi, Jiaqi Han, S. Sahoo, and V. Kuleshov. Block diffusion: Interpolating between autoregressive and diffusion language models, 2025. URL [http://arxiv.org/abs/2503.09573](http://arxiv.org/abs/2503.09573).
- Wu et al. [2025a]. Chengyue Wu, Hao Zhang, Shuchen Xue, Zhijian Liu, Shizhe Diao, Ligeng Zhu, Ping Luo, Song Han, and Enze Xie. Fast-dllm: Training-free acceleration of diffusion llm by enabling kv cache and parallel decoding. *arXiv.org*, 2025a. doi:10.48550/arXiv.2505.22618.
- Wu et al. [2025b]. Chengyue Wu, Hao Zhang, Shuchen Xue, Shizhe Diao, Yonggan Fu, Zhijian Liu, Pavlo Molchanov, Ping Luo, Song Han, and Enze Xie. Fast-dllm v2: Efficient block-diffusion llm, 2025b. URL [https://arxiv.org/abs/2509.26328](https://arxiv.org/abs/2509.26328).
- Zheng et al. [2024]. Kaiwen Zheng, Yongxin Chen, Hanzi Mao, Ming-Yu Liu, Jun Zhu, and Qinsheng Zhang. Masked diffusion models are secretly time-agnostic masked models and exploit inaccurate categorical sampling. In *International Conference on Learning Representations*, 2024. doi:10.48550/arXiv.2409.02908. URL [https://openreview.net/forum?id=CTC7CmirNr](https://openreview.net/forum?id=CTC7CmirNr).
- Nie et al. [2025]. Shen Nie, Fengqi Zhu, Chao You, Xiaojie Zhang, Jingyang Ou, and Jun Zhu. LLaDA: Large language diffusion with autoregressive initialization, 2025. URL [http://arxiv.org/abs/2502.09992](http://arxiv.org/abs/2502.09992).
- Ben-Hamu et al. [2025]. Heli Ben-Hamu, Itai Gat, Daniel Severo, Niklas Nolte, and Brian Karrer. Accelerated Sampling from Masked Diffusion Models via Entropy Bounded Unmasking, 2025.
- Kim et al. [2025]. Jaeyeon Kim, Kulin Shah, Vasilis Kontonis, Sham M. Kakade, and Sitan Chen. Train for the worst, plan for the best: Understanding token ordering in masked diffusions. In *International Conference on Machine Learning*, 2025. doi:10.48550/arXiv.2502.06768. URL [https://openreview.net/forum?id=DjJmre5IkP](https://openreview.net/forum?id=DjJmre5IkP).
- Patel et al. [2025]. Dhruvesh Patel, Tahira Naseem, Gaurav Pandey, M. Sultan, Andrew McCallum, and Ramón Fernandez. Improved sampling from masked diffusion models with position contrastive guidance. In *NeurIPS 2025 Workshop on Structured Probabilistic Inference & Generative Modeling*, 2025. URL [https://openreview.net/forum?id=e0WmOrWbtc](https://openreview.net/forum?id=e0WmOrWbtc).
- Kim et al. [2026]. Jaeyeon Kim, Jonathan Geuter, David Alvarez-Melis, S. Kakade, and Sitan Chen. Stop training for the worst: Progressive unmasking accelerates masked diffusion training, 2026. URL [http://arxiv.org/abs/2602.10314](http://arxiv.org/abs/2602.10314).
- McGuire et al. [2012]. Gary McGuire, Bastian Tugemann, and Gilles Civario. There is no 16-clue sudoku: Solving the sudoku minimum number of clues problem via hitting set enumeration. *Experimental Mathematics*, 23(2):190–217, 2012. doi:10.1080/10586458.2013.870056.
- Wang et al. [2025]. Guan Wang, Jin Li, Yuhao Sun, Xing Chen, Chang-Le Liu, Yue Wu, Meng Lu, Sen Song, and Yasin Abbasi-Yadkori. Hierarchical reasoning model. *arXiv.org*, 2025. doi:10.48550/arXiv.2506.21734.
- Vink [2024]. Tim Vink. sudoku-solver: a python Sudoku solver that traces the human-style strategies it uses. [https://github.com/timvink/sudoku-solver](https://github.com/timvink/sudoku-solver), 2024.
- Yang et al. [2024]. Qwen An Yang, Baosong Yang, Beichen Zhang, Binyuan Hui, Bo Zheng, Bowen Yu, Chengyuan Li, Dayiheng Liu, Fei Huang, Guanting Dong, et al. Qwen2.5 technical report, 2024. URL [https://arxiv.org/abs/2412.15115](https://arxiv.org/abs/2412.15115).
- Bercovich et al. [2025]. A. Bercovich, Itay Levy, Izik Golan, Mohammad Dabbah, Ran El-Yaniv, Omri Puny, Ido Galil, Zach Moshe, Tomer Ronen, Najeeb Nabwani, et al. Llama-nemotron: Efficient reasoning models. *arXiv.org*, 2025. doi:10.48550/arXiv.2505.00949.
- Liu et al. [2023]. Jiawei Liu, Chunqiu Steven Xia, Yuyao Wang, and Lingming Zhang. Is your code generated by ChatGPT really correct? Rigorous evaluation of large language models for code generation. In *Neural Information Processing Systems*, pages 21558–21572, 2023. doi:10.52202/075280-0943. URL [https://openreview.net/forum?id=1qvx610Cu7](https://openreview.net/forum?id=1qvx610Cu7).
- Ho et al. [2020]. Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models, 2020. URL [http://arxiv.org/abs/2006.11239](http://arxiv.org/abs/2006.11239).
- Song et al. [2020]. Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations, 2020. URL [http://arxiv.org/abs/2011.13456](http://arxiv.org/abs/2011.13456).
- Ye et al. [2025]. Jiacheng Ye, Zhihui Xie, Lin Zheng, Jiahui Gao, Zirui Wu, Xin Jiang, Zhenguo Li, and Lingpeng Kong. Dream 7b: Diffusion large language models, 2025. URL [https://arxiv.org/abs/2508.15487](https://arxiv.org/abs/2508.15487).
- Jo et al. [2025]. Mingyu Jo, Jaesik Yoon, Justin Deschenaux, Caglar Gulcehre, and Sungjin Ahn. Loopholing discrete diffusion: Deterministic bypass of the sampling wall, 2025. URL [http://arxiv.org/abs/2510.19304](http://arxiv.org/abs/2510.19304).
- Xia et al. [2026]. Kejing Xia, Mingzhe Li, Lixuan Wei, Zhenbang Du, Xiangchi Yuan, Qirui Jin, and Wenke Lee. MetaState: Persistent working memory for discrete diffusion language models, 2026. URL [http://arxiv.org/abs/2603.01331](http://arxiv.org/abs/2603.01331).
- Zheng et al. [2025]. Huangjie Zheng, Shansan Gong, Ruixiang Zhang, Tianrong Chen, Jiatao Gu, Mingyuan Zhou, Navdeep Jaitly, and Yizhe Zhang. Continuously augmented discrete diffusion model for categorical generative modeling, 2025. URL [http://arxiv.org/abs/2510.01329](http://arxiv.org/abs/2510.01329).
- Hersche et al. [2025]. Michael Hersche, Samuel Moor-Smith, Thomas Hofmann, and Abbas Rahimi. Soft-masked diffusion language models, 2025. URL [http://arxiv.org/abs/2510.17206](http://arxiv.org/abs/2510.17206).
- Xie et al. [2025]. Tianyu Xie, Shuchen Xue, Zijin Feng, Tianyang Hu, Jiacheng Sun, Zhenguo Li, and Cheng Zhang. Variational autoencoding discrete diffusion with enhanced dimensional correlations modeling, 2025. URL [http://arxiv.org/abs/2505.17384](http://arxiv.org/abs/2505.17384).
- Gong et al. [2025]. Shansan Gong, Mukai Li, Jiangtao Feng, Zhiyong Wu, and LingPeng Kong. Generative recursive reasoning models. In *International Conference on Learning Representations (ICLR)*, 2025. URL [https://openreview.net/pdf?id=Vxu6kcIjwV](https://openreview.net/pdf?id=Vxu6kcIjwV).
- Patel et al. [2026]. Dhruvesh Patel, Durga Prasad Maram, Sai Sreenivas Chintha, Benjamin Rozonoyer, and Andrew McCallum. xLM: A python package for non-autoregressive language models. In Danilo Croce, Jochen Leidner, and Nafise Sadat Moosavi, editors, *Proceedings of the 19th Conference of the European Chapter of the ACL (Volume 3: System Demonstrations)*, pages 445–456, Rabat, Morocco, March 2026. Association for Computational Linguistics. doi:10.18653/v1/2026.eacl-demo.31. URL [https://aclanthology.org/2026.eacl-demo.31/](https://aclanthology.org/2026.eacl-demo.31/).

## Appendix

###### Contents

1. [1 Introduction](https://arxiv.org/html/2605.22967#S1)
2. [2 Background: Masked Diffusion Models](https://arxiv.org/html/2605.22967#S2)
3. [3 Learned Relay Representations](https://arxiv.org/html/2605.22967#S3)
   1. [3.1 Augmented State Trajectories](https://arxiv.org/html/2605.22967#S3.SS1)
   2. [3.2 Training](https://arxiv.org/html/2605.22967#S3.SS2)
4. [4 Experiments](https://arxiv.org/html/2605.22967#S4)
   1. [4.1 Sudoku](https://arxiv.org/html/2605.22967#S4.SS1)
   2. [4.2 Pretrained Model Adaptation: Fast-dLLM v2](https://arxiv.org/html/2605.22967#S4.SS2)
      1. [4.2.1 Training memory overhead](https://arxiv.org/html/2605.22967#S4.SS2.SSS1)
5. [5 Related Work](https://arxiv.org/html/2605.22967#S5)
6. [6 Discussion](https://arxiv.org/html/2605.22967#S6)
7. [7 Conclusion](https://arxiv.org/html/2605.22967#S7)
8. [References](https://arxiv.org/html/2605.22967#bib)
9. [Appendix](https://arxiv.org/html/2605.22967#Ax1)
10. [A Experimental Details](https://arxiv.org/html/2605.22967#A1)
    1. [A.1 Sudoku](https://arxiv.org/html/2605.22967#A1.SS1)
       1. [A.1.1 Dataset](https://arxiv.org/html/2605.22967#A1.SS1.SSS1)
       2. [A.1.2 Model Architecture](https://arxiv.org/html/2605.22967#A1.SS1.SSS2)
       3. [A.1.3 Training Hyperparameters](https://arxiv.org/html/2605.22967#A1.SS1.SSS3)
    2. [A.2 Fast-dLLM v2](https://arxiv.org/html/2605.22967#A1.SS2)
       1. [A.2.1 Dataset](https://arxiv.org/html/2605.22967#A1.SS2.SSS1)
       2. [A.2.2 Model Architecture](https://arxiv.org/html/2605.22967#A1.SS2.SSS2)
       3. [A.2.3 Fast-dLLM v2 training hardware and parallelism](https://arxiv.org/html/2605.22967#A1.SS2.SSS3)
11. [B Fast-dLLM v2 memory profiling](https://arxiv.org/html/2605.22967#A2)

## Appendix A Experimental Details

### A.1 Sudoku

#### A.1.1 Dataset

Sudoku Extreme. We train and evaluate on our derived dataset, built on top of Sudoku-Extreme \[Gong et al., [2025](https://arxiv.org/html/2605.22967#bib.bib32), Wang et al., [2025](https://arxiv.org/html/2605.22967#bib.bib19)\] by running the solver of Vink \[[2024](https://arxiv.org/html/2605.22967#bib.bib20)\] over every puzzle (derived dataset: [https://huggingface.co/datasets/brozonoyer/sapientinc-sudoku-extreme-timvink-sudoku-solver](https://huggingface.co/datasets/brozonoyer/sapientinc-sudoku-extreme-timvink-sudoku-solver); solver code: [https://github.com/timvink/sudoku-solver](https://github.com/timvink/sudoku-solver)). The base dataset consists of 9×9 Sudoku puzzles with 17 given clues, the minimum number compatible with a uniquely solvable puzzle \[McGuire et al., [2012](https://arxiv.org/html/2605.22967#bib.bib18)\]. Each puzzle is represented as a flat sequence of length $L=81$ over a vocabulary of $|\mathcal{V}|=11$ task tokens: digits $\{1,\ldots,9\}$, a blank/zero token for unfilled cells, and a mask token. The clue positions are treated as fixed and are not modified during inference; the remaining 64 positions are mutable. Our derived version augments each puzzle with:

- `trajectory`: step-by-step board states from question to solution
- `num_steps`: number of solver calls to reach the solution
- `strategies_used`: set of human-like deduction strategies invoked (used by the deduction-only cohort below)
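A sketch of how a puzzle maps onto the flat 81-token representation described above. The token ids (blank = 0, digits 1 to 9 as themselves, mask = 10) and the use of `.` or `0` for blank cells in the puzzle string are assumptions for illustration; only the 11-token task vocabulary and the fixed/mutable split come from the text.

```python
MASK = 10  # hypothetical mask id; digits 1-9 map to themselves and the blank to 0


def encode_puzzle(puzzle: str):
    """Return (question tokens, initial denoising input, mutable-position flags)."""
    if len(puzzle) != 81:
        raise ValueError("expected a flat 9x9 board of length 81")
    question = [0 if ch in ".0" else int(ch) for ch in puzzle]
    mutable = [tok == 0 for tok in question]  # clue cells are fixed, the rest can change
    z0 = [MASK if m else tok for tok, m in zip(question, mutable)]
    return question, z0, mutable
```

For a 17-clue puzzle, `sum(mutable)` is 64, matching the number of mutable positions stated above.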

We train on the training split (3,831,994 puzzles), evaluate on the test split (422,786 puzzles), and validate on the first 100 batches at batch size 512 (51,200 puzzles) per checkpoint.

Deduction-only cohort. For the qualitative legality analysis of [Section 4.1](https://arxiv.org/html/2605.22967#S4.SS1) we use the `strategies_used` field described above to filter puzzles. Since the studied architectures cannot perform recursive search, we keep only puzzles whose solver trace contains *Advanced* (Naked Pair, Hidden Pair, Naked Triple, Hidden Triple, Naked Quad, Hidden Quad) or *Master* (X-Wing, Swordfish, Jellyfish, Forcing Chain) strategies and never falls back on recursive backtracking. The resulting cohort contains 2,000 test puzzles (1,933 Advanced + 67 Master).

Evaluation protocol for [Table 1](https://arxiv.org/html/2605.22967#S4.T1). Each cell of [Table 1](https://arxiv.org/html/2605.22967#S4.T1) aggregates the first $N=2000$ puzzles from the Hugging Face test split in dataset order; for the deduction-only cohort, we keep the first 2,000 test examples whose solver trace uses Advanced or Master strategies without recursive backtracking.
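The cohort selection can be sketched as a streaming filter over the test split in dataset order. The strategy names come from the tiers listed above; the `BACKTRACKING` label for the solver's recursive fallback and the rule that any Master strategy places a puzzle in the Master tier are our assumptions, not documented fields of the derived dataset.

```python
from itertools import islice

ADVANCED = {"Naked Pair", "Hidden Pair", "Naked Triple",
            "Hidden Triple", "Naked Quad", "Hidden Quad"}
MASTER = {"X-Wing", "Swordfish", "Jellyfish", "Forcing Chain"}
BACKTRACKING = "Backtracking"  # hypothetical label for the recursive fallback


def is_deduction_only(strategies_used) -> bool:
    """Uses an Advanced or Master strategy and never backtracks."""
    used = set(strategies_used)
    return bool(used & (ADVANCED | MASTER)) and BACKTRACKING not in used


def tier(strategies_used) -> str:
    """Assumed tiering: any Master strategy makes the puzzle a Master puzzle."""
    return "Master" if set(strategies_used) & MASTER else "Advanced"


def deduction_only_cohort(rows, n=2000):
    """First n rows, in dataset order, that pass the deduction-only filter."""
    passing = (row for row in rows if is_deduction_only(row["strategies_used"]))
    return list(islice(passing, n))
```

Because `islice` stops after `n` matches, the filter never scans the rest of the 422,786-puzzle split once the cohort is full.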

#### A.1.2 Model Architecture

The backbone for all Sudoku experiments ([Table 1](https://arxiv.org/html/2605.22967#S4.T1)) is a shallow rotary Transformer:

- Depth / width: $L=4$ layers, hidden dimension $d_{\mathrm{model}}=384$, feedforward width $4d_{\mathrm{model}}=1536$
- Attention: $H=6$ heads (head dimension $d_{\mathrm{model}}/H=64$), rotary positional embeddings (rotary width 64)
- MLP: ReLU nonlinearities, dropout 0.1
- Vocabulary: digits $\{0,\ldots,9\}$ plus special tokens, $|\mathcal{V}|=17$

The Relay variant adds a differentiable carry channel following Jo et al. [[2025](https://arxiv.org/html/2605.22967#bib.bib27)]. At each inference step, the relay tensor $h_t$ from the previous step is normalized by an affine LayerNorm ($\varepsilon_{\mathrm{LN}}=10^{-5}$), yielding $\delta_t=\mathrm{LN}_{\mathrm{relay}}(h_t)$, and is injected additively into the residual stream before layer zero: $x\leftarrow\mathrm{Embed}(x_t)+\delta_t$. The outgoing relay state $h_{t+1}$ is read from the final transformer block, and the logits are always produced from the same terminal hidden states. We initialize $\mathrm{LN}_{\mathrm{relay}}$ with PyTorch defaults ($\boldsymbol{\gamma}_{\mathrm{relay}}\leftarrow\mathbf{1}$, $\boldsymbol{\beta}_{\mathrm{relay}}\leftarrow\mathbf{0}$). We implement all models with the xLM package [Patel et al., [2026](https://arxiv.org/html/2605.22967#bib.bib33)], which provides a unified interface for training and inference of non-autoregressive language models, making the ablations and experiments easy to reproduce.
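The relay injection can be sketched without any framework, assuming per-token feature vectors stored as Python lists. In PyTorch, the same step would typically be a `torch.nn.LayerNorm(d_model, eps=1e-5)` applied to the relay tensor and added to the token embeddings. The function names are hypothetical, and the sketch covers only the input side ($\delta_t$ and the residual addition), not how $h_{t+1}$ is read out of the final block.

```python
import math


def layer_norm(h, gamma, beta, eps=1e-5):
    """Affine LayerNorm over one token's features (biased variance, as in PyTorch)."""
    n = len(h)
    mean = sum(h) / n
    var = sum((x - mean) ** 2 for x in h) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (x - mean) * inv_std + b for x, g, b in zip(h, gamma, beta)]


def inject_relay(embeddings, relay, gamma, beta, eps=1e-5):
    """x <- Embed(x_t) + LN_relay(h_t), applied independently at each position."""
    return [
        [e + d for e, d in zip(emb, layer_norm(h, gamma, beta, eps))]
        for emb, h in zip(embeddings, relay)
    ]


emb = [[0.5, -1.0, 2.0, 0.0]]   # one token, d_model = 4 for illustration
h = [[1.0, 2.0, 3.0, 4.0]]      # relay state carried over from the previous step
default_init = inject_relay(emb, h, [1.0] * 4, [0.0] * 4)
zero_init = inject_relay(emb, h, [0.0] * 4, [0.0] * 4)
```

With the PyTorch-default initialization ($\gamma=1$, $\beta=0$), each token receives a zero-mean relay offset with variance close to 1; with $\gamma=0$ and $\beta=0$, the offset is exactly zero and the embeddings pass through unchanged.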

Parameter counts \(with and without weight tying\) are:

- Baseline (MLM / rollout-buffer only): 7,105,536 untied; 7,099,008 tied
- Relay: 7,106,304 untied; 7,099,776 tied
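The reported totals can be reproduced exactly under one plausible layout. The reconstruction also shows that Relay's extra 768 parameters are precisely the $2\times 384$ affine weights of $\mathrm{LN}_{\mathrm{relay}}$, and that untying adds one $17\times 384$ output matrix. The layout itself is our reconstruction from the numbers, not something read from the xLM source, and other layouts could give the same totals. It assumes bias-free attention projections, biased feedforward layers, two affine LayerNorms per block, a final LayerNorm, and a bias-free output head.

```python
def transformer_params(n_layers=4, d_model=384, ff_mult=4, vocab=17,
                       tied=True, relay=False):
    """Parameter count under an assumed layout (see lead-in), not read from xLM."""
    d, ff = d_model, ff_mult * d_model
    attn = 4 * d * d                  # Q, K, V, O projections, assumed bias-free
    mlp = d * ff + ff + ff * d + d    # two linear layers, assumed with biases
    norms = 2 * (2 * d)               # two affine LayerNorms per block (gamma, beta)
    total = n_layers * (attn + mlp + norms)
    total += vocab * d                # token embedding; rotary adds no parameters
    total += 2 * d                    # final LayerNorm
    if not tied:
        total += vocab * d            # separate output projection, assumed bias-free
    if relay:
        total += 2 * d                # LN_relay gamma and beta
    return total


for relay in (False, True):
    print(relay, transformer_params(tied=False, relay=relay),
          transformer_params(tied=True, relay=relay))
```

Rotary embeddings contribute nothing here because they are computed from positions rather than learned.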

#### A.1.3 Training Hyperparameters

- Batch size: 512 (single GPU, bf16 mixed precision)
- Optimizer: AdamW, learning rate $5\times 10^{-4}$, weight decay $10^{-2}$
- LR schedule: 2,000-step linear warmup, then constant with no decay
- Gradient clipping: global Frobenius norm 0.5
- BPTT unroll horizon: $K=2$ steps (Relay runs only)
- Confidence threshold: $\tau=0.15$ on the maximum softmax probability, perturbed by $\mathcal{N}(0,0.1^2)$ noise during training and fixed at inference
- Validation: every 5,000 steps on 100 batches, with a threshold sweep over $\tau\in\{0.05,0.10,\ldots,0.25\}$
- Total steps: 300,000; results are reported in [Table 1](https://arxiv.org/html/2605.22967#S4.T1)
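The learning-rate schedule and the confidence-threshold rule above can be sketched as follows. Two details are our assumptions rather than stated settings. First, the warmup ramps as $(\text{step}+1)/2000$. Second, when no masked position clears $\tau$, the single most confident one is committed so that decoding always makes progress. The `rng` argument can be any object with a `gauss` method, such as the `random` module or a `random.Random` instance.

```python
import random


def learning_rate(step: int, peak_lr: float = 5e-4, warmup_steps: int = 2000) -> float:
    """Linear warmup to peak_lr over warmup_steps, then constant (no decay)."""
    # Assumed ramp: (step + 1) / warmup_steps, so the peak is reached at step 1999.
    return peak_lr * min(1.0, (step + 1) / warmup_steps)


def sample_threshold(tau: float = 0.15, sigma: float = 0.1,
                     training: bool = True, rng=random) -> float:
    """tau + N(0, sigma^2) noise during training; fixed tau at inference."""
    return tau + rng.gauss(0.0, sigma) if training else tau


def positions_to_commit(max_probs, is_masked, tau):
    """Masked positions whose maximum softmax probability is at least tau."""
    masked = [i for i, m in enumerate(is_masked) if m]
    chosen = [i for i in masked if max_probs[i] >= tau]
    # Assumption: if nothing clears tau, commit the single most confident
    # masked position so decoding always makes progress.
    if not chosen and masked:
        chosen = [max(masked, key=lambda i: max_probs[i])]
    return chosen


tau = sample_threshold(training=False)
print(positions_to_commit([0.9, 0.1, 0.2, 0.05], [True, True, False, True], tau))
```

Jittering $\tau$ during training exposes the model to a range of unmasking rates, which is consistent with sweeping $\tau$ at validation time.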

### A.2 Fast-dLLM v2

#### A.2.1 Dataset

##### OpenCode/OpenMath c40m60 mixture for Fast-dLLM v2 adaptation.

For Fast-dLLM v2 adaptation, we use a 60k-example supervised fine-tuning mixture drawn from [nvidia/OpenCodeInstruct](https://huggingface.co/datasets/nvidia/OpenCodeInstruct) and [nvidia/OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2), with 24k code examples and 36k math examples (a 40/60 code/math split, hence the name c40m60). We filter for high-quality prompt–answer pairs, remove examples that contaminate the held-out evaluation sets, format each example as a one-turn conversation, and cap sequences at 2048 tokens.

#### A.2.2 Model Architecture

Unlike standard MDMs, which denoise the entire token sequence globally, Fast-dLLM v2 [Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)] models a block-wise Markov process. By partitioning the sequence into blocks of size $D$, it targets the local conditional distribution $p_{\theta}(x^{b}\mid x_{t}^{b},x_{0}^{<b})$, which predicts the clean block $x^b$ from its noised version $x_t^b$ and the clean prefix $x_0^{<b}$ of earlier blocks. This localizes the diffusion process while anchoring it to an autoregressive prefix, avoiding the large pretraining cost associated with full-attention MDMs.

The core architectural shift lies in the attention topology. Fast-dLLM v2 concatenates the noised sequence $x_t$ and the clean sequence $x_0$ into a tensor of length $2L$, governed by a full attention mask $\mathcal{M}_{\mathrm{full}}\in\{0,1\}^{2L\times 2L}$ [Arriola et al., [2025](https://arxiv.org/html/2605.22967#bib.bib9)]. This mask decomposes into three components with distinct functional roles:

- $\mathcal{M}_{BD}$: Enables bidirectional attention among the noised tokens of the same block.
- $\mathcal{M}_{OBC}$: Allows each noised block $x_t^b$ to attend to the fully denoised, clean prefix $x_0^{<b}$.
- $\mathcal{M}_{BC}$: Enforces standard left-to-right causality among the clean tokens.
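The three mask components can be sketched as a boolean $2L\times 2L$ matrix in plain Python. The sketch places the noised copy $x_t$ first (indices $0$ to $L-1$) and the clean copy $x_0$ second, as in the description above. The function name and the `token_causal` switch are hypothetical. The default follows this section's description of token-level causality among clean tokens, while `token_causal=False` gives a block-causal alternative for comparison.

```python
def block_diffusion_mask(seq_len: int, block_size: int, token_causal: bool = True):
    """2L x 2L boolean mask over [x_t ; x_0]; True means query q may attend to key k."""
    L = seq_len

    def blk(i: int) -> int:
        return (i % L) // block_size  # block index within either copy

    mask = [[False] * (2 * L) for _ in range(2 * L)]
    for q in range(2 * L):
        for k in range(2 * L):
            q_noised, k_noised = q < L, k < L
            if q_noised and k_noised:          # M_BD: bidirectional within a block
                mask[q][k] = blk(q) == blk(k)
            elif q_noised:                     # M_OBC: noised block sees clean prefix
                mask[q][k] = blk(k) < blk(q)
            elif not k_noised:                 # M_BC: causality among clean tokens
                mask[q][k] = k <= q if token_causal else blk(k) <= blk(q)
            # Clean queries never attend to noised keys (entry stays False).
    return mask


for row in block_diffusion_mask(seq_len=4, block_size=2):
    print("".join("x" if allowed else "." for allowed in row))
```

Printing the mask for $L=4$ and $D=2$ shows the block-diagonal $\mathcal{M}_{BD}$ in the top-left quadrant, the offset $\mathcal{M}_{OBC}$ in the top-right, an all-zero bottom-left, and a lower-triangular $\mathcal{M}_{BC}$ in the bottom-right.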

The $2L$ concatenation lets the noised and clean views be processed in a single forward pass. On top of this, a complementary masking strategy trains on both a sampled mask $m$ and its complement $\bar{m}=1-m$, so that every token in the input contributes supervision rather than only those masked under $m$.
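A sketch of complementary masking. It assumes each position is masked independently with probability `mask_prob`; the actual noise schedule is not specified here, and the function name is hypothetical. Because every position is masked in exactly one of $m$ and $\bar{m}$, the two views together supervise the whole input.

```python
import random


def complementary_masks(seq_len: int, mask_prob: float, rng=random):
    """Sample a mask m and its complement 1 - m (1 = position is masked)."""
    m = [1 if rng.random() < mask_prob else 0 for _ in range(seq_len)]
    m_bar = [1 - x for x in m]
    return m, m_bar


m, m_bar = complementary_masks(16, 0.3, rng=random.Random(0))
print(m)
print(m_bar)
```

Each position is supervised once per pair of views, so the effective supervision no longer depends on how many positions a single random draw happens to mask.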

At inference, this topology enables hierarchical key-value (KV) caching, a major advantage over standard MDMs, which typically recompute the full sequence at every denoising step. Fully denoised blocks $x_0^{<b}$ are cached as read-only context, while a DualCache handles prefix and suffix activations within the active, partially noised block $x_t^b$.

#### A.2.3 Fast-dLLM v2 training hardware and parallelism

All adaptation runs use DeepSpeed ZeRO-3 with bf16 mixed precision and gradient checkpointing on two NVIDIA A100 80GB GPUs, with per-device batch size 2 and gradient accumulation 16 (an effective batch of $2\times 16=32$ sequences per device per optimizer step, or 64 across both GPUs under data parallelism). For Relay adaptation, $\mathrm{LN}_{\mathrm{relay}}$ uses a zero-initialized $\boldsymbol{\gamma}_{\mathrm{relay}}$ (with $\boldsymbol{\beta}_{\mathrm{relay}}=0$), so the relay contribution is exactly zero at initialization and early forward passes behave like the model without a relay until training updates $\boldsymbol{\gamma}_{\mathrm{relay}}$ [Wu et al., [2025b](https://arxiv.org/html/2605.22967#bib.bib11)].
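For reference, the batch-size arithmetic can be written as the corresponding fields of a DeepSpeed config dictionary. The key names are standard DeepSpeed options, and DeepSpeed expects `train_batch_size` to equal the per-GPU micro-batch times the accumulation steps times the number of GPUs. The helper function itself is hypothetical, not the authors' configuration file, and gradient checkpointing is enabled on the model rather than in this dictionary.

```python
def deepspeed_batch_config(micro_batch_per_gpu: int = 2, grad_accum: int = 16,
                           world_size: int = 2) -> dict:
    """Batch-size, precision, and ZeRO fields of a DeepSpeed config (sketch)."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum,
        # DeepSpeed checks: train_batch_size == micro batch * accumulation * world size
        "train_batch_size": micro_batch_per_gpu * grad_accum * world_size,
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": 3},
    }


cfg = deepspeed_batch_config()
print(cfg["train_batch_size"])
```

With the settings above, each GPU accumulates 32 sequences per optimizer step, and the global batch seen by one update is 64.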

## Appendix B Fast-dLLM v2 memory profiling

This section gives the protocol and per-phase numbers behind [Figure 3](https://arxiv.org/html/2605.22967#S4.F3), repeated below.

![[Uncaptioned image]](https://arxiv.org/html/2605.22967v1/x4.png)

Setup. We profile a single training micro-step of Fast-dLLM v2 on the OpenCode/OpenMath c40m60 mixture under the same hardware and parallelism as the main runs ([Section A.2.3](https://arxiv.org/html/2605.22967#A1.SS2.SSS3)): two A100 80GB GPUs, DeepSpeed ZeRO-3, bf16, and gradient checkpointing, with sequence length 2048 and per-device batch size 2. Production runs use gradient accumulation 16; profiling forces accumulation to 1 and replaces `optimizer.step` with a no-op, so that the recorded peak is attributable to a single forward/backward pair rather than to optimizer-state allocation.

Instrumentation. On every decoder-layer forward and backward hook, we log `memory_allocated` and `max_memory_allocated` from `torch.cuda` (the solid and dashed traces in [Figure 3](https://arxiv.org/html/2605.22967#S4.F3), respectively). We call `reset_peak_memory_stats()` once at the start of the profiled micro-step, so the dashed series is a within-step high-water mark rather than a long-run accumulator. All measurements are taken in eager mode with `torch.compile` and FlashAttention 2 disabled, so steps in the dashed series correspond directly to discrete kernel-level allocations. Phase labels (`fwd`, `fwd2`, `bwd`) are placed at each phase's plateau in the dashed series. The horizontal axes are profiler event indices (88 events for vanilla SFT, 174 for Relay) and are not directly comparable across the two curves.

Per-phase peaks. The dashed peak traces settle at 20,618 MiB (≈ 20.1 GiB) for Relay versus 21,683 MiB (≈ 21.2 GiB) for vanilla SFT. In both regimes, the peak is set at the start of `bwd`, when the cross-entropy backward through `lm_head` transiently allocates a $B\times T\times V$ gradient-of-logits buffer, which HuggingFace materializes in fp32 for numerical stability ($[2, 2048, 151936]\times 4\,\text{B}\approx 2.3$ GiB), on top of the bf16 logits tensor being differentiated. Relay's second forward raises the live trace by ≈ 5 GiB through `fwd2`, because the saved activations of forward 1 and the relay state $\boldsymbol{h}$ must coexist with forward 2 to provide credit through both passes ([Algorithm 1](https://arxiv.org/html/2605.22967#algorithm1)). The live trace reaches 17,057 MiB at `relay_fwd2_end`, which is still below the 20,618 MiB peak that the CE backward sets one event later, so the live plateau through `fwd2` never becomes the binding peak. The ≈ 1 GiB gap in Relay's favor between the two final peaks is structural rather than allocator noise: by the time the CE backward fires, Relay has released more autograd intermediates from forward 1 than vanilla SFT releases by the analogous event ($\Delta\text{live}=-7{,}119$ MiB across this transition for Relay vs. $-2{,}727$ MiB for vanilla). Traces from rank 1 reproduce both peaks to within 0.1 MiB, so [Figure 3](https://arxiv.org/html/2605.22967#S4.F3) shows only rank 0.
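The buffer size and unit conversions quoted above can be checked with a few lines of arithmetic. The helper names are hypothetical, and the defaults mirror the profiled logits shape $[2, 2048, 151936]$.

```python
def logits_buffer_gib(batch: int = 2, seq_len: int = 2048, vocab: int = 151_936,
                      bytes_per_elem: int = 4) -> float:
    """Size of one B x T x V tensor in GiB (4 bytes = fp32, 2 bytes = bf16)."""
    return batch * seq_len * vocab * bytes_per_elem / 2**30


def mib_to_gib(mib: float) -> float:
    return mib / 1024


print(logits_buffer_gib())                  # fp32 gradient-of-logits buffer
print(logits_buffer_gib(bytes_per_elem=2))  # bf16 logits tensor
print(mib_to_gib(20_618), mib_to_gib(21_683), mib_to_gib(21_683 - 20_618))
```

The fp32 buffer comes to about 2.32 GiB, twice the roughly 1.16 GiB bf16 logits tensor it sits on top of, and the 1,065 MiB difference between the two final peaks is about 1.04 GiB.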
