Block-Based Double Decoders
Summary
Proposes block-based double decoders, a novel transformer architecture using doubly-causal block-based attention masks to combine decoder-only training efficiency with encoder-decoder inference efficiency, achieving strong scaling performance and reduced KV-cache memory.
Cached at: 05/20/26, 08:37 AM
# Block-Based Double Decoders
Source: [https://arxiv.org/html/2605.18807](https://arxiv.org/html/2605.18807)
Vanessa Alexander (vanessa\_alexander@brown.edu), Benjamin Bradley (benjamin\_bradley@brown.edu), Chaitanya Harsha (chaitanya\_harsha@brown.edu), Asher Labovich (asher\_labovich@brown.edu). Department of Computer Science, Brown University, Providence, RI 02912
###### Abstract
Encoder-decoder models offer substantial inference-time savings over decoder-only models, but their pretraining objectives suffer from sparse supervision and dynamic sequence lengths, which has kept them out of practice at scale. We propose *block-based double decoders*, a novel transformer architecture that uses doubly-causal block-based attention masks to train with full loss supervision and static sequence packing, combining decoder-only training efficiency with encoder-decoder inference efficiency. In scaling-law experiments, block-based double decoders strongly outperform encoder-decoders and closely track decoder-only models across scales. At inference time, they cut KV-cache memory and per-token compute by at least $\frac{2}{3}$ without sacrificing prefill caching or other existing inference optimizations available to decoder-only models.
## 1 Introduction
In recent years, the rise of the transformer \[[16](https://arxiv.org/html/2605.18807#bib.bib1)\] has led to advances in natural language modeling. Although the original transformer employed an encoder-decoder architecture, with full attention for understanding in the encoder and causal attention for prediction in the decoder, decoder-only architectures have gained prominence for their scalability in text-generation settings \[[11](https://arxiv.org/html/2605.18807#bib.bib3)\]. Recently, however, there has been renewed interest in encoder-decoder architectures \[[3](https://arxiv.org/html/2605.18807#bib.bib4)\] for their efficiency and their effectiveness in compute- and storage-constrained environments, as they substantially reduce KV-cache size.
However, \[[12](https://arxiv.org/html/2605.18807#bib.bib5)\] showed that the architectural boundary between the encoder and decoder introduces inefficiencies: a $2P$-parameter encoder-decoder model incurs the same computational cost as a $P$-parameter decoder-only model. An alternative single-transformer approach, PrefixLM, uses full attention over a prefix of the input and causal attention after it, combining the bidirectional context of encoder-decoders with the parameter sharing of decoder-only models. This method requires less dynamic batching, but leaves many tokens without a training signal.
PrefixLM is, in turn, outperformed by an encoder-decoder trained with span corruption, in which the model learns to predict a proportion of "corrupted" tokens. Although this achieves superior performance, it requires highly dynamic batching and still leaves many tokens without a training signal, since \[[12](https://arxiv.org/html/2605.18807#bib.bib5)\] found that corrupting a mere 15% of tokens gives the best performance.
We propose a novel attention mask for pre-training (code available at [https://github.com/ashlab11/block-based-double-decoder](https://github.com/ashlab11/block-based-double-decoder)), which we call *doubly-causal block-based masking*. Our proposed architecture consists of two decoders. The first uses a standard causal mask over the standard input. The second decoder attends both to the input and to the output of the first decoder, masked with our approach: we split the input into "blocks," apply causal self-attention over the input tokens within each block, and cross-attend to the first decoder's outputs for all preceding blocks. In this way, every token in the input contributes a loss signal, and the sequence length stays constant across batches.
Figure 1: Loss versus training tokens for decoder-only, double decoder, and encoder-decoder models of different sizes.
## 2 Prior Work
The original Transformer architecture \[[16](https://arxiv.org/html/2605.18807#bib.bib1)\] consists of a bidirectional self-attention encoder paired with a causal decoder for autoregressive generation. The encoder processes the full input sequence to produce contextualized representations for each token. In contrast, the decoder applies masked (causal) self-attention and additionally incorporates cross-attention, allowing each decoding position to attend to all of the encoded input representations. However, in recent years, it has been found that for generative modeling, decoder-only models are more scalable and achieve superior performance \[[11](https://arxiv.org/html/2605.18807#bib.bib3), [1](https://arxiv.org/html/2605.18807#bib.bib6)\]. Meanwhile, encoder-only architectures \[[2](https://arxiv.org/html/2605.18807#bib.bib7)\], trained using token masking, are performant when fine-tuned on downstream classification-type tasks.
However, \[[12](https://arxiv.org/html/2605.18807#bib.bib5)\] demonstrated that classification tasks can be cast as text-to-text problems, putting them within reach of generative models. Their study included a thorough search over architectures and pre-training objectives, including decoder-only with the standard LM objective, decoder-only with the PrefixLM objective, and encoder-decoder with span corruption, among others. Across the downstream tasks they evaluated, including question answering, summarization, and translation, the encoder-decoder with span corruption performed best. In span corruption, contiguous spans of tokens are "corrupted," i.e. each span is replaced with a sentinel token that is unique within the example. The target sequence consists of the concatenated missing spans, and the model is trained to autoregressively generate them. Typically, only a fraction of the tokens (about 15%) are corrupted, with major performance degradation occurring at 50% corruption.
Empirically, the encoder-decoder architecture combined with span corruption has been shown to outperform both decoder-only language models (LMs) and their PrefixLM variants under comparable training settings. Span corruption has also been shown to outperform the standard LM objective when the model architecture is held fixed.
We note that, despite its strong performance, span corruption has several issues as a pre-training objective. First, \[[12](https://arxiv.org/html/2605.18807#bib.bib5)\] studied which corruption rate is most performant and confirmed the finding of \[[2](https://arxiv.org/html/2605.18807#bib.bib7)\] that 15% corruption leads to optimal training. However, this means that only those 15% of tokens yield a loss signal, so most tokens are not trained on. This creates a fundamental tradeoff between increasing prediction difficulty to enhance long-range reasoning and maximizing supervision density relative to the total computational cost. Second, span corruption requires dynamic batching, since the number of predicted tokens varies across the training examples in a batch. This in turn requires token padding, which wastes compute and memory. We hypothesize that an alternative pre-training objective can perform as well as span corruption without these inefficiencies.
Additionally, prior work such as \[[18](https://arxiv.org/html/2605.18807#bib.bib8)\] demonstrated that attention patterns capturing both local dependencies and global information can achieve strong performance at reduced computational cost. Although we target training-time efficiency rather than $O(n)$ attention, we take inspiration from this idea with an architecture that captures both local and global interdependencies.
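The span-corruption format and the two inefficiencies just described can be seen on a toy sequence. The following is a minimal sketch, not the T5 implementation: spans are supplied by hand rather than sampled, `span_corrupt` is a hypothetical helper, and the `<extra_id_i>` sentinel names follow the T5 vocabulary's convention but are used here only as plain strings.

```python
def span_corrupt(tokens, spans):
    """Apply T5-style span corruption (illustrative sketch).

    `spans` is a sorted list of non-overlapping half-open (start, end) intervals.
    Each span collapses to one sentinel in the input; the target lists every
    sentinel followed by the tokens it replaced, plus a closing sentinel.
    """
    inputs, targets = [], []
    pos = 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[pos:start])   # uncorrupted tokens pass through
        inputs.append(sentinel)            # the whole span becomes one sentinel
        targets.append(sentinel)
        targets.extend(tokens[start:end])  # only these tokens are predicted
        pos = end
    inputs.extend(tokens[pos:])
    targets.append(f"<extra_id_{len(spans)}>")
    return inputs, targets


tokens = [f"t{i}" for i in range(20)]
spans = [(3, 5), (12, 13)]
inputs, targets = span_corrupt(tokens, spans)
corrupted = sum(end - start for start, end in spans)
print(len(inputs), len(targets), corrupted / len(tokens))
```

Here only 3 of the 20 tokens (15%) are prediction targets, and both the input and target lengths depend on how many spans were drawn, which is why examples with different span counts must be padded when batched together.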
## 3 Architecture Comparison
### 3.1 Block-based double decoders
*\[Figure 2 diagram: attention mask over the tokens BOS, A, B, C, D, E, F, with keys → along one axis and queries → along the other; legend: self-attention, cross-attention.\]*
Figure 2: Visual explanation of the generation-decoder attention mask for an example sentence. The partition splits the sentence into three parts, so the decoder sees three context-response pairs in parallel:
(1): empty context, \[BOS, A\] response
(2): \[BOS, A\] context, \[B, C, D\] response
(3): \[BOS, A, B, C, D\] context, \[E, F\] response
Each token thus appears in the loss exactly once per forward pass.

In this paper, we aim to create an architecture for next-token prediction that retains the benefits of encoder-decoder architectures (in particular, substantial KV-cache reductions and ease of use on edge devices) without sacrificing the training efficiency of decoder-only transformers. We consider two criteria when designing such an architecture. First, *loss-information density:* every token in a packed training example should contribute a loss signal in every forward pass. Second, *static sequence length:* the post-packing sequence length should not change across training batches, so that throughput is predictable and maximally efficient.
To motivate our architecture, we first consider a basic PrefixLM encoder-decoder training objective, which fails both criteria. To train one, we must randomly choose a split point for each batch, placing $N\%$ of the tokens in the encoder and $(100-N)\%$ in the decoder. Only the decoder portion contributes to the loss, so the loss-information density is $(100-N)\%$, and token counts fluctuate by batch depending on $N$. This results in both dynamic-length batching (since $N$ changes by batch) and less loss information per token, which together substantially harm training efficiency. We solve both problems with our proposed architecture, block-based double decoders.
Concretely, a block-based double decoder consists of two stacks. The first, which we call the *context decoder*, is a standard causal decoder-only transformer; it takes in the full input and outputs causal latents $h_t$ for each token. The second, the *generation decoder*, takes in three inputs: (1) the causal latents from the context decoder, (2) the token sequence itself, and (3) a *block partition*: a strictly increasing sequence of indices $0=b_0<b_1<\ldots<b_K=T$ that splits the sequence into $K$ contiguous subsequences. For the length-7 sequence in Figure [2](https://arxiv.org/html/2605.18807#S3.F2), the partition (0, 2, 5, 7) creates the blocks \[BOS, A\], \[B, C, D\], \[E, F\].
Within the generation decoder, each query at position $t$ in block $k$ attends to two subsequences: the *within-block keys*, comprising the tokens of block $k$ at positions $\leq t$ (i.e. causal self-attention), and the *cross-block keys*, comprising the context-decoder latents $h_s$ for all $s$ in blocks $m<k$ (full cross-attention). The mask in Figure [2](https://arxiv.org/html/2605.18807#S3.F2) is the union of these two attention mechanisms.
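The block partition can be illustrated with a short sketch. How partitions are sampled is not specified in this section, so `sample_partition` uses a hypothetical scheme (interior cut points drawn uniformly without replacement, which keeps every block non-empty); `split_blocks` reproduces the Figure 2 example.

```python
import random


def sample_partition(T, K, rng):
    """Draw a partition 0 = b_0 < b_1 < ... < b_K = T.

    Hypothetical sampling scheme: the K - 1 interior cut points are chosen
    uniformly without replacement from 1..T-1, so every block is non-empty.
    """
    if not 1 <= K <= T:
        raise ValueError("need 1 <= K <= T")
    cuts = sorted(rng.sample(range(1, T), K - 1))
    return [0] + cuts + [T]


def split_blocks(tokens, partition):
    """Return the K contiguous subsequences defined by the partition."""
    return [tokens[a:b] for a, b in zip(partition, partition[1:])]


seq = ["BOS", "A", "B", "C", "D", "E", "F"]
print(split_blocks(seq, [0, 2, 5, 7]))  # the three blocks from Figure 2
print(sample_partition(len(seq), 3, random.Random(0)))
```

Because only the cut points change from batch to batch, the packed sequence length $T$ is the same for every batch.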
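The two masks can be written as per-(query, key) predicates. The sketch below is illustrative, with hypothetical helper names: it builds the within-block (causal) and cross-block predicates for the Figure 2 partition and materializes them as 0/1 matrices. The predicates take `(b, h, q_idx, kv_idx)` arguments, the shape PyTorch FlexAttention uses for `mask_mod` callables; with `ids` as a device tensor they could plausibly be passed there, but that usage is an assumption and is not exercised here (plain Python lists are used instead).

```python
def block_ids(partition):
    """Map each position 0..T-1 to the index of the block that contains it."""
    ids = []
    for k, (start, end) in enumerate(zip(partition, partition[1:])):
        ids.extend([k] * (end - start))
    return ids


def make_mask_mods(ids):
    """Build (self_mod, cross_mod) predicates over (b, h, q_idx, kv_idx).

    b and h (batch and head) are unused: the mask is identical everywhere.
    """

    def self_mod(b, h, q_idx, kv_idx):
        # Within-block keys: same block as the query, at or before its position.
        return (ids[q_idx] == ids[kv_idx]) & (kv_idx <= q_idx)

    def cross_mod(b, h, q_idx, kv_idx):
        # Cross-block keys: context-decoder latents from strictly earlier blocks.
        return ids[kv_idx] < ids[q_idx]

    return self_mod, cross_mod


def dense(mod, T):
    """Materialize a predicate as a T x T 0/1 matrix (rows = queries, cols = keys)."""
    return [[int(mod(0, 0, q, kv)) for kv in range(T)] for q in range(T)]


ids = block_ids([0, 2, 5, 7])  # blocks [BOS, A], [B, C, D], [E, F]
self_mod, cross_mod = make_mask_mods(ids)
for name, mod in (("self-attention", self_mod), ("cross-attention", cross_mod)):
    print(name)
    for row in dense(mod, len(ids)):
        print(" ".join(map(str, row)))
```

The first two rows of the cross-attention matrix (queries BOS and A) are all zeros, which is the empty-row case that any rule for combining the two attentions has to handle, while every self-attention row contains at least its own diagonal entry.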
There is a subtlety in how these two operations are combined. Treating them as two separate attention mechanisms and residually adding their outputs, as is common in standard encoder-decoders, introduces three problems. First, the order in which the two operations are applied becomes architecturally significant despite being completely arbitrary. Second, certain query rows have no keys under the cross-attention mask (e.g. \[BOS, A\] in Figure [2](https://arxiv.org/html/2605.18807#S3.F2)), leaving the corresponding softmax undefined. Third, the attention-sink behavior of standard attention is lost: with two separate softmaxes, each of the two subsequences is forced to receive a full unit of probability mass, even when only one of them is relevant for predicting the next token. All three problems are solved at once by computing a single attention with two different key matrices, chosen according to the (query, key) index pair. However, since to our knowledge no fast attention implementation currently supports this dual-key mechanism, we instead compute the two attentions *separately* and combine them post hoc using their log-sum-exp normalizers. This is mathematically equivalent to the single attention operation, but allows a fast implementation via PyTorch's FlexAttention. We note that this is substantially less computationally efficient than the "ideal" method, which would exploit the sparsity induced by the chosen blocks to minimize the total multiplications across the two SDPA calls, leaving only the cost of computing the additional KV matrix. We leave the development of this ideal method to future work.
This architecture satisfies both of the criteria above. Every token appears in the loss exactly once per forward pass, regardless of the number of blocks chosen (though the first block has an empty context, and thus must conduct all of its reasoning in the generation decoder). In addition, the only dynamic component is the block partition, which changes each batch; however, it affects only the generation decoder's attention mask, which PyTorch's FlexAttention handles with minimal overhead. Finally, the context decoder retains the benefits of standard encoder-decoder architectures: at inference time, it runs once over the prompt, so its own KV-cache need not be kept during generation, yielding substantial speed-ups.
Each of the three architectures in this paper (decoder-only, encoder-decoder, and double decoder) has substantially different compute requirements during both training and inference, even when parameter- and token-matched. The following sections describe these differences in detail.
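The log-sum-exp combination can be checked on a single query row. The sketch below uses plain Python and hypothetical helper names: it computes softmax attention separately over the within-block keys and the cross-block keys, keeps each part's log-sum-exp, and merges them. The result matches one softmax over the union of keys, and an empty cross-block key set (the first block) simply receives zero weight. It illustrates the algebra only, not FlexAttention's kernels; recent PyTorch versions can return the log-sum-exp from `flex_attention` via a `return_lse` flag, which a merge like this could consume.

```python
import math


def partial_attention(scores, values, dim):
    """Softmax attention over one key set for a single query.

    Returns (output vector, log-sum-exp of the scores). An empty key set yields
    a zero output with log-sum-exp -inf, so it gets zero weight when merged.
    """
    if not scores:
        return [0.0] * dim, -math.inf
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    out = [sum(w * v[i] for w, v in zip(weights, values)) / z for i in range(dim)]
    return out, m + math.log(z)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial attentions over disjoint key sets into one softmax."""
    m = max(lse_a, lse_b)
    if m == -math.inf:  # neither part has any visible key
        return list(out_a), m
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse


# One query: two within-block keys and three cross-block keys, value dim 2.
self_scores, self_vals = [0.3, -1.2], [[1.0, 0.0], [0.0, 1.0]]
cross_scores, cross_vals = [2.0, 0.5, -0.7], [[0.5, 0.5], [1.0, -1.0], [2.0, 0.0]]
joint, _ = partial_attention(self_scores + cross_scores, self_vals + cross_vals, 2)
split, _ = merge(*partial_attention(self_scores, self_vals, 2),
                 *partial_attention(cross_scores, cross_vals, 2))
print(max(abs(a - b) for a, b in zip(joint, split)))
```

The printed difference is at floating-point rounding level. Since there is a single normalizer over both key sets, the order of the two operations no longer matters, and either subsequence can absorb almost all of the probability mass.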
### 3.2 Training Time Comparisons
Although our primary focus is on the large inference-latency differences between decoder-only, encoder-decoder, and double decoder models, the three architectures also differ in training compute, and the standard $6NT$ heuristic obscures these differences. In Appendix [A.2](https://arxiv.org/html/2605.18807#A1.SS2), we derive architecture-aware FLOP formulas: double decoders incur additional compute from their extra KV projections, while encoder-decoders save compute by feeding fewer tokens through the decoder under span corruption. Specifically, for a decoder-only model with sequence length $T$, hidden dimension $d$, and $L$ layers, the approximate train-time FLOP count is
$$L\left(72Td^{2}+12T^{2}d\right).$$

For an encoder-decoder trained with span corruption, with padded encoder input sequence length $T_{in}$ and decoder sequence length $T_{out}$, this is

$$L\left((52T_{in}+28T_{out})d^{2}+(4T_{out}^{2}+4T_{in}T_{out}+8T_{in}^{2})d\right),$$

and for an efficient implementation of the double decoder,

$$L\left(76Td^{2}+12T^{2}d\right).$$
Although double decoder models require more training compute than decoder-only models, the difference is small when implemented efficiently: at $T=T_{in}=2048$ and $T_{out}=256$, the double decoder uses only $2.4\%$ more FLOPs than decoder-only (while the encoder-decoder uses $21\%$ *fewer* FLOPs). Combining these formulas with our empirical scaling laws lets us compare the compute required by each architecture to reach matched perplexity on held-out test sets.
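The double decoder's relative overhead follows directly from the two formulas above: the layer count $L$ cancels, and the ratio minus one simplifies to $\frac{4d}{72d+12T}=\frac{d}{18d+3T}$, so it depends on the hidden dimension. The hidden dimension behind the $2.4\%$ figure is not stated in this section; $d=256$ is our assumption, chosen because it gives $\approx 2.38\%$. The function names below are illustrative.

```python
def decoder_only_flops(T, d, L=1):
    """Approximate training FLOPs for a decoder-only model: L(72 T d^2 + 12 T^2 d)."""
    return L * (72 * T * d**2 + 12 * T**2 * d)


def double_decoder_flops(T, d, L=1):
    """Approximate training FLOPs for an efficient double decoder: L(76 T d^2 + 12 T^2 d)."""
    return L * (76 * T * d**2 + 12 * T**2 * d)


def relative_overhead(T, d):
    # L cancels; this equals 4d / (72d + 12T) = d / (18d + 3T).
    return double_decoder_flops(T, d) / decoder_only_flops(T, d) - 1


for d in (256, 512, 1024):
    print(d, f"{relative_overhead(2048, d):.2%}")
```

At $T=2048$ this prints 2.38%, 3.33%, and 4.17% for the three widths: the extra KV projections matter more as $d$ grows relative to $T$, approaching $\frac{4}{72}\approx 5.6\%$ when $d \gg T$.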
### 3.3 Inference Time Comparison
As mentioned in Section [3.2](https://arxiv.org/html/2605.18807#S3.SS2), double decoders require slightly more training compute than decoder-only models due to the additional KV matrices in every transformer block. However, this disadvantage is strongly outweighed by their benefits during inference, which we detail below. All of these benefits arise directly from the natural separation of context and response in both training and inference, and one of them is unavailable to classic encoder-decoder models due to their bidirectional encoders.
1. Since the generation decoder needs only the final output of the context decoder, there is no need to retain the context decoder's intermediate activations. This saves both time and memory: \[[3](https://arxiv.org/html/2605.18807#bib.bib4)\] found 4.7x higher throughput at low context lengths, and \[[15](https://arxiv.org/html/2605.18807#bib.bib13)\] noted that these throughput gaps only increase as context length grows. In addition, the KV-cache required during generation scales only with the size of the generation decoder, rather than the full model. As detailed in Appendix [A.1](https://arxiv.org/html/2605.18807#A1.SS1), this yields a $\mathbf{\frac{2}{3}}$ KV-cache memory reduction relative to decoder-only models under our architectural design, with the benefit growing as the ratio of context-decoder to generation-decoder layers grows. These benefits also carry over to per-token latency: once the context-decoder representation is computed, generation depends only on the size of the generation decoder, so per-token latency shrinks by the same factor as memory. These advantages also connect to test-time compute scaling: if compute can be inflated for *only* the context decoder at test time (i.e. without adding more tokens or increasing compute in the generation decoder), then TTFT increases while per-token latency and memory requirements remain unchanged. Recent work on looped transformers \[[8](https://arxiv.org/html/2605.18807#bib.bib9), [10](https://arxiv.org/html/2605.18807#bib.bib10), [4](https://arxiv.org/html/2605.18807#bib.bib11), [13](https://arxiv.org/html/2605.18807#bib.bib12)\] has found that looping layers at test time can improve performance and generalization; doing so only for the context decoder in our architecture may combine these reasoning gains with the low memory and compute costs of dual-stack models.
2. In scenarios where memory is particularly scarce, models with two separate stacks (including both standard encoder-decoders and our block-based double decoder) can swap stacks between CPU and GPU as needed. A model split 2/3 encoder and 1/3 decoder can save 1/3 of its parameter memory while running the context decoder, then swap the placement of the two stacks when generation begins.
3. When model providers expect many prompts to share a common prefix (e.g. a long shared system prompt), they often *prefill* the KV-cache for that prefix once and reuse it across requests, dramatically reducing TTFT. While immensely useful in decoder-only models (especially since system prompts are typically long and thus dominate TTFT without prefill), this form of caching is fundamentally incompatible with encoder-decoder models: the bidirectional encoder makes every token's representation depend on the full input, so a cached prefix cannot be reused once the suffix changes. Our double decoder directly resolves this, as its doubly-causal design lets prefix-level KV caching from decoder-only inference stacks transfer immediately.
4. The architectural separation in our model also reduces latency. Time-to-first-token (TTFT) depends only on the depth of the context decoder, so it scales as $\frac{L_{enc}}{L}$ relative to a decoder-only model, which corresponds to a $\mathbf{\frac{1}{3}}$ reduction in TTFT *under the same split*. Notably, the benefits of prefill (as used in \[[7](https://arxiv.org/html/2605.18807#bib.bib18)\] and other inference packages) transfer immediately to our model in a way they *do not* to traditional encoder-decoder models, because of the latter's bidirectional encoder. Given that many recent latency improvements come from better prefill caching, this is an *immediate gain requiring essentially no fundamental changes to inference code.* In addition to lower TTFT, our model also substantially reduces the latency of each generated token. Per-token generation cost depends only on the depth of the generation decoder, as each new token is processed autoregressively through this stack. As a result, per-token latency scales as $\frac{L_{dec}}{L}$, yielding a corresponding $\frac{2}{3}$ reduction under the same 2/3 encoder, 1/3 decoder split.
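The ratios claimed in the list above can be reproduced with simple layer counting. This sketch (hypothetical function name) assumes all layers share one width, that every caching layer stores one K/V pair per token (for generation-decoder layers, cross-attention entries over the prompt plus self-attention entries over generated tokens), and it ignores embeddings, the extra cross-attention projections, and the single generation-decoder step included in TTFT.

```python
from fractions import Fraction


def inference_ratios(l_context, l_generation):
    """Double-decoder costs relative to a decoder-only model of depth
    l_context + l_generation, under the equal-width assumptions stated above."""
    l_total = l_context + l_generation
    return {
        # Only generation-decoder layers keep a KV-cache while decoding.
        "kv_cache": Fraction(l_generation, l_total),
        # Each new token passes through the generation decoder only.
        "per_token_latency": Fraction(l_generation, l_total),
        # The prompt passes through the context decoder only.
        "ttft": Fraction(l_context, l_total),
        # Parameter memory that can sit on the CPU while the context decoder runs.
        "offloaded_params": Fraction(l_generation, l_total),
    }


for name, ratio in inference_ratios(8, 4).items():
    print(f"{name}: {ratio}")
```

With the 8/4 split used in our experiments, the KV-cache and per-token latency fall to $\frac{1}{3}$ of decoder-only (a $\frac{2}{3}$ reduction), TTFT falls to $\frac{2}{3}$ (a $\frac{1}{3}$ reduction), and $\frac{1}{3}$ of the parameters can be offloaded during the context pass.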
## 4 Experimental Setup
#### Models and scaling grid\.
We compare three architectures: decoder-only, standard encoder-decoder (SED), and our proposed double decoder. All share a tokenizer, sequence length 2048, and tied input/output embeddings, with parameter counts approximately matched across families. The main sweep is a three-way grid over architecture, model size, and token budget, spanning 6.25M to 100M parameters and 62.5M to 1B tokens in multiples of 2. We vary the hidden dimension $d$ across the grid with $\frac{d}{64}$ attention heads, so that head dimension stays at 64, and hold the depth profile at 8 encoder / 4 decoder layers (or 12 decoder layers for decoder-only models) for most configurations. The smallest models use 10/5 instead, because the multiple-of-64 width constraint leaves too little resolution to hit small parameter targets by width alone. SED's decoder carries an extra cross-attention sublayer per block, so we reduce its decoder layer count to match non-embedding parameters at fixed width. Table [1](https://arxiv.org/html/2605.18807#A1.T1) lists the hyperparameters shared by all models. We train all models on 10 NVIDIA H100 GPUs, for a total of 200 GPU-hours for the reported results. All experiments were conducted with PyTorch \[[9](https://arxiv.org/html/2605.18807#bib.bib16)\].
As described in Section [3.2](https://arxiv.org/html/2605.18807#S3.SS2), we report architecture-aware FLOP counts rather than the common $6NT$ heuristic used for decoder-only models.
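The sweep's size and token axes can be enumerated directly. This sketch assumes the grid is a full Cartesian product (every architecture at every size and budget); the values are nominal parameter targets, since actual counts are only approximately matched, and the width $d$ used for each size is not listed in this section. Names are illustrative.

```python
from itertools import product

ARCHS = ("decoder_only", "encoder_decoder", "double_decoder")
PARAMS = [6.25e6 * 2**i for i in range(5)]  # 6.25M, 12.5M, 25M, 50M, 100M (nominal)
TOKENS = [62.5e6 * 2**i for i in range(5)]  # 62.5M, 125M, 250M, 500M, 1B


def num_heads(d, head_dim=64):
    """Number of attention heads for hidden width d at a fixed head dimension."""
    if d % head_dim:
        raise ValueError(f"width {d} is not a multiple of {head_dim}")
    return d // head_dim


runs = list(product(ARCHS, PARAMS, TOKENS))
print(len(runs), num_heads(512))
```

Under the full-grid assumption this gives 75 runs, and a width of 512 corresponds to 8 heads of dimension 64.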
#### μP and width transfer.
We adopt a width-scaling protocol inspired by maximal update parameterization (μP) \[[17](https://arxiv.org/html/2605.18807#bib.bib17)\] with base width $d_0=64$. Hidden-matrix learning rates are scaled by $\frac{d_0}{d}$, embeddings and the tied output projection use the architecture-specific base learning rate, and norms and biases skip weight decay. After the tied output projection, logits are multiplied by $\sqrt{\frac{d_0}{d}}$. Base learning rates are tuned on a 0.5M-parameter model for all three architectures (0.01 for decoder-only and double decoder, 0.004 for encoder-decoder) and confirmed to transfer to models up to 16x larger. After fixing the learning rate, we swept weight decay on the same 0.5M-parameter model and found 0.5 best for the double decoder and encoder-decoder, and 0.1 best for decoder-only. Appendix [A.3](https://arxiv.org/html/2605.18807#A1.SS3) visualizes these sweeps.
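The width-scaling rules above can be summarized as a small settings function. The numeric values are the ones reported in this section; the dictionary layout, function names, and the `norm`/`bias` parameter-naming convention used to exclude parameters from weight decay are hypothetical, and other parts of μP (initialization scales, attention scaling) are not covered.

```python
import math

BASE_WIDTH = 64  # d_0
BASE_LR = {"decoder_only": 0.01, "double_decoder": 0.01, "encoder_decoder": 0.004}
WEIGHT_DECAY = {"decoder_only": 0.1, "double_decoder": 0.5, "encoder_decoder": 0.5}


def mup_settings(arch, d):
    """Width-scaled optimizer settings for an architecture at hidden width d."""
    base_lr = BASE_LR[arch]
    return {
        "hidden_matrix_lr": base_lr * BASE_WIDTH / d,   # scaled by d_0 / d
        "embedding_lr": base_lr,                        # embeddings + tied output projection
        "logit_multiplier": math.sqrt(BASE_WIDTH / d),  # applied after the tied projection
        "weight_decay": WEIGHT_DECAY[arch],
    }


def uses_weight_decay(param_name):
    # Hypothetical naming convention: norm parameters and biases skip weight decay.
    return not ("norm" in param_name or param_name.endswith("bias"))


print(mup_settings("double_decoder", 256))
```

At $d=256$ (four times the base width), hidden-matrix learning rates drop by a factor of 4 while the embedding learning rate is unchanged, and logits are halved.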
#### Pretraining data and objectives\.
All models are pretrained on packed SlimPajama \[[14](https://arxiv.org/html/2605.18807#bib.bib14)\] with architecture-native objectives: the double decoder uses our doubly-causal block-based masking method, SED uses T5-style span corruption with sentinels, and decoder-only uses causal next-token prediction. Because these losses correspond to different prediction problems, we do not compare raw pretraining losses across families. Instead, we add a post-hoc prefix-LM fine-tuning phase (10% of training data) on held-out SlimPajama with a base learning rate of 0.0002 (before μP scaling).
The collator samples a breakpoint for each batch and trains the model to predict suffix tokens conditioned on prefix tokens. For the double decoder and SED, the prefix is routed through the encoder (for the double decoder, the context decoder); for decoder-only, the prefix is processed as usual but contributes no loss. This equalizes the loss objective and data, allowing a fairer cross-architecture comparison.
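A minimal sketch of the prefix-LM fine-tuning collator described above. Assumptions: all sequences in a batch share one length and one breakpoint, the breakpoint is drawn uniformly from 1..T-1 so both sides are non-empty, and `prefix_lm_batch` and its returned dictionary layout are hypothetical rather than the authors' collator.

```python
import random


def prefix_lm_batch(batch, rng, arch):
    """Split a batch of equal-length token lists at one shared random breakpoint.

    Two-stack models get the prefix and suffix as separate inputs; decoder-only
    models get the full sequence plus a per-token loss mask over target tokens
    (1 = suffix token that contributes to the loss).
    """
    T = len(batch[0])
    cut = rng.randint(1, T - 1)  # at least one token on each side
    if arch == "decoder_only":
        loss_mask = [0] * cut + [1] * (T - cut)
        return {"input": batch, "loss_mask": [list(loss_mask) for _ in batch], "cut": cut}
    return {
        "context": [seq[:cut] for seq in batch],     # encoder / context-decoder input
        "generation": [seq[cut:] for seq in batch],  # tokens to predict
        "cut": cut,
    }


batch = [list(range(16)), list(range(100, 116))]
print(prefix_lm_batch(batch, random.Random(0), "decoder_only")["cut"])
```

Both routes predict exactly the same suffix tokens from the same prefix, which is what makes the resulting losses comparable across architectures.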
## 5 Results
Across the entire (N, D) grid, our block-based double decoder substantially outperforms the encoder-decoder baseline and closely tracks decoder-only (Figures [1](https://arxiv.org/html/2605.18807#S1.F1), [4](https://arxiv.org/html/2605.18807#S5.F4), [5](https://arxiv.org/html/2605.18807#S5.F5)). At the largest configuration we evaluated (100M parameters, 1B tokens), the double decoder reaches an evaluation loss $\sim 0.2$ nats worse than the parameter- and token-matched decoder-only model, while the encoder-decoder trails *both* architectures by $\sim 0.7$ nats. This relative ordering (decoder-only $\lesssim$ double decoder $\ll$ encoder-decoder) holds across every model size and token budget in the sweep.
The most informative difference is visible along the compute axis in Figure [4](https://arxiv.org/html/2605.18807#S5.F4). For both decoder-only and double decoder models, individual size curves cross over, tracing out the Pareto envelope characteristic of well-known scaling laws \[[6](https://arxiv.org/html/2605.18807#bib.bib2)\]. In contrast, across the entire FLOP range tested, larger encoder-decoder models are essentially *uniformly* worse than smaller ones at every fixed compute level, and the curves remain nearly parallel on the log axis. We interpret this as evidence that all tested encoder-decoder configurations sit in the same scaling regime, specifically the *data-limited* regime, in which additional parameters cannot be utilized within the token budgets explored. Figure [1](https://arxiv.org/html/2605.18807#S1.F1) provides additional evidence for this claim: while the decoder-only and double decoder curves flatten visibly past $\sim 50$M tokens as smaller models hit their capacity floors, the encoder-decoder curves continue descending steeply at every model size with no sign of flattening. This finding directly motivates block-based double decoders. As argued earlier, span corruption receives loss information from only $\sim 15\%$ of tokens during the forward pass, so the encoder-decoder receives a sparse signal and cannot fill its capacity within the token budgets we test. Restoring full-token loss information via block-based double decoders yields substantial improvements over the encoder-decoder across the entire compute grid.

Figure 3: Cross-entropy loss after training for each parameter/token combination.

We attribute the $\sim 0.2$ nat gap between the decoder-only and double decoder models to the cost of architecturally separating context and response. Section [3.3](https://arxiv.org/html/2605.18807#S3.SS3) details the inference-time benefits this separation enables, which we view as more than compensating for the modest training-time loss gap.

Figure 4: Loss versus FLOPs for each model size.

Figure 5: Loss versus FLOPs for models grouped by training-token budget.
## 6 Limitations and Future Work
While our method achieves strong results, it has several limitations. First, block partitions are sampled randomly per batch, which may increase the variability and unpredictability of loss dynamics. Second, due to compute constraints we were limited to relatively small models, and our scaling-law extrapolations may not hold several orders of magnitude beyond the regime we directly consider. Training a select few larger models would substantially strengthen the evidence for extrapolation to much larger scales.
Beyond the inference-time savings discussed here, the double decoder architecture opens several promising research directions. Latent chain-of-thought methods such as COCONUT \[[5](https://arxiv.org/html/2605.18807#bib.bib15)\] benefit from inserting continuous reasoning tokens between context and response, but the absence of a native context-response boundary in decoder-only architectures has confined these techniques to post-training. Our architecture exposes this separation throughout training, in principle opening the door to applying latent reasoning during pretraining. As discussed in Section [3.3](https://arxiv.org/html/2605.18807#S3.SS3), looped transformers \[[8](https://arxiv.org/html/2605.18807#bib.bib9), [10](https://arxiv.org/html/2605.18807#bib.bib10)\] also naturally synergize with the context decoder, possibly combining test-time reasoning gains with low-latency inference. We leave direct testing of these possibilities to future work.
## 7 Broader Impacts
Improving the efficiency of transformer models through our block-based double decoder architecture and doubly-causal block-based masking method may have positive societal impact by reducing computational cost and memory usage during inference, at a small additional cost during training. These efficiency gains can make large-scale language models more accessible to researchers with limited computational resources. Notably, our method does not introduce new capabilities or fundamentally alter model behavior, so its societal risks are generally inherited from existing LMs.
## 8 Division of Labor
All authors contributed equally\. Asher and Ben focused primarily on coding and running models, while Vanessa and Chai focused on research and writing\.
## References
- \[1\] T. B. Brown, B. Mann, N. Ryder, M. Subbiah, J. Kaplan, P. Dhariwal, A. Neelakantan, P. Shyam, G. Sastry, A. Askell, S. Agarwal, A. Herbert-Voss, G. Krueger, T. Henighan, R. Child, A. Ramesh, D. M. Ziegler, J. Wu, C. Winter, C. Hesse, M. Chen, E. Sigler, M. Litwin, S. Gray, B. Chess, J. Clark, C. Berner, S. McCandlish, A. Radford, I. Sutskever, and D. Amodei (2020). Language models are few-shot learners. arXiv:2005.14165. [Link](https://arxiv.org/abs/2005.14165)
- \[2\] J. Devlin, M. Chang, K. Lee, and K. Toutanova (2019). BERT: pre-training of deep bidirectional transformers for language understanding. arXiv:1810.04805. [Link](https://arxiv.org/abs/1810.04805)
- \[3\] M. Elfeki, R. Liu, and C. Voegele (2025). Return of the encoder: maximizing parameter efficiency for SLMs. arXiv:2501.16273. [Link](https://arxiv.org/abs/2501.16273)
- \[4\] J. Geiping, S. McLeish, N. Jain, J. Kirchenbauer, S. Singh, B. R. Bartoldson, B. Kailkhura, A. Bhatele, and T. Goldstein (2025). Scaling up test-time compute with latent reasoning: a recurrent depth approach. arXiv:2502.05171. [Link](https://arxiv.org/abs/2502.05171)
- \[5\] S. Hao, S. Sukhbaatar, D. Su, X. Li, Z. Hu, J. Weston, and Y. Tian (2024). Training large language models to reason in a continuous latent space. arXiv:2412.06769.
- \[6\] J. Hoffmann, S. Borgeaud, A. Mensch, E. Buchatskaya, T. Cai, E. Rutherford, D. de Las Casas, L. A. Hendricks, J. Welbl, A. Clark, T. Hennigan, E. Noland, K. Millican, G. van den Driessche, B. Damoc, A. Guy, S. Osindero, K. Simonyan, E. Elsen, J. W. Rae, O. Vinyals, and L. Sifre (2022). Training compute-optimal large language models. arXiv:2203.15556. [Link](https://arxiv.org/abs/2203.15556)
- \[7\] W. Kwon, Z. Li, S. Zhuang, Y. Sheng, L. Zheng, C. H. Yu, J. E. Gonzalez, H. Zhang, and I. Stoica (2023). Efficient memory management for large language model serving with PagedAttention. In Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles.
- \[8\] A. Labovich (2026). Stability and generalization in looped transformers. arXiv:2604.15259. [Link](https://arxiv.org/abs/2604.15259)
- \[9\] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Köpf, E. Z. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala (2019). PyTorch: an imperative style, high-performance deep learning library. CoRR abs/1912.01703. [Link](http://arxiv.org/abs/1912.01703)
- \[10\] H. Prairie, Z. Novack, T. Berg-Kirkpatrick, and D. Y. Fu (2026). Parcae: scaling laws for stable looped language models. arXiv:2604.12946. [Link](https://arxiv.org/abs/2604.12946)
- \[11\] A. Radford, K. Narasimhan, T. Salimans, and I. Sutskever (2018). Improving language understanding by generative pre-training. Technical report, OpenAI. [Link](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)
- \[12\] C. Raffel, N. Shazeer, A. Roberts, K. Lee, S. Narang, M. Matena, Y. Zhou, W. Li, and P. J. Liu (2023). Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv:1910.10683. [Link](https://arxiv.org/abs/1910.10683)
- \[13\] N. Saunshi, N. Dikkala, Z. Li, S. Kumar, and S. J. Reddi (2025). Reasoning with latent thoughts: on the power of looped transformers. arXiv:2502.17416. [Link](https://arxiv.org/abs/2502.17416)
- \[14\]D\. Soboleva, F\. Al\-Khateeb, R\. Myers, J\. R\. Steeves, J\. Hestness, and N\. Dey\(2023\-06\)SlimPajama: A 627B token cleaned and deduplicated version of RedPajama\.Note:[https://cerebras\.ai/blog/slimpajama\-a\-627b\-token\-cleaned\-and\-deduplicated\-version\-of\-redpajama](https://cerebras.ai/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama)External Links:[Link](https://huggingface.co/datasets/cerebras/SlimPajama-627B)Cited by:[§4](https://arxiv.org/html/2605.18807#S4.SS0.SSS0.Px3.p1.1)\.
- \[15\]Y\. Sun, L\. Dong, Y\. Zhu, S\. Huang, W\. Wang, S\. Ma, Q\. Zhang, J\. Wang, and F\. Wei\(2024\)You only cache once: decoder\-decoder architectures for language models\.External Links:2405\.05254,[Link](https://arxiv.org/abs/2405.05254)Cited by:[item 1](https://arxiv.org/html/2605.18807#S3.I1.i1.p1.1)\.
- \[16\]A\. Vaswani, N\. Shazeer, N\. Parmar, J\. Uszkoreit, L\. Jones, A\. N\. Gomez, Ł\. Kaiser, and I\. Polosukhin\(2017\)Attention is all you need\.InProceedings of the 31st International Conference on Neural Information Processing Systems,NIPS’17,Red Hook, NY, USA,pp\. 6000–6010\.External Links:ISBN 9781510860964Cited by:[§1](https://arxiv.org/html/2605.18807#S1.p1.1),[§2](https://arxiv.org/html/2605.18807#S2.p1.1)\.
- \[17\]G\. Yang, E\. J\. Hu, I\. Babuschkin, S\. Sidor, X\. Liu, D\. Farhi, N\. Ryder, J\. Pachocki, W\. Chen, and J\. Gao\(2022\)Tensor programs v: tuning large neural networks via zero\-shot hyperparameter transfer\.External Links:2203\.03466,[Link](https://arxiv.org/abs/2203.03466)Cited by:[§4](https://arxiv.org/html/2605.18807#S4.SS0.SSS0.Px2.p1.3)\.
- \[18\]M\. Zaheer, G\. Guruganesh, A\. Dubey, J\. Ainslie, C\. Alberti, S\. Ontanon, P\. Pham, A\. Ravula, Q\. Wang, L\. Yang, and A\. Ahmed\(2021\)Big bird: transformers for longer sequences\.External Links:2007\.14062,[Link](https://arxiv.org/abs/2007.14062)Cited by:[§2](https://arxiv.org/html/2605.18807#S2.p5.1)\.
## Appendix A Additional Calculations and Results
### A.1 KV-cache calculations at inference time
Here, we compare the KV-cache memory of a decoder-only model with that of a dual-stack model (encoder-decoder or double decoder) in terms of key architectural parameters. We define the following variables:

- $d$: hidden state size
- $L_{enc}$: layers in the encoder / context-decoder
- $L_{dec}$: layers in the decoder (as part of an encoder-decoder) / generation-decoder
- $L$: layers in the decoder-only model
- $T_{in}$: number of tokens in the context
- $T_{out}$: number of tokens in the output
- $b$: bytes per activation (e.g., 2 in fp16)

For simplicity, we assume standard multi-head attention (MHA), so each layer caches a $d$-dimensional key and value for every token.
For a decoder-only model, the KV-cache memory is simple to compute: every token has a key and a value representation in every layer, so the memory requirement is

$$2dbL(T_{in}+T_{out}).$$

For dual-stack models, the calculation is more involved. For self-attention, each of the $T_{out}$ output tokens attends only to prior output tokens across all decoder layers, giving a cache of $2dbL_{dec}T_{out}$. For cross-attention, each decoder layer stores a key and value representation of the final encoder (or context-decoder) state for all $T_{in}$ context tokens, giving a cache of $2dbL_{dec}T_{in}$. Combined, this results in a cache of

$$2dbL_{dec}(T_{in}+T_{out}).$$

Since the only inference-time difference between double decoders and encoder-decoders is that double decoders combine the attention blocks with a log-sum-exp rather than summing them, both have the same memory advantage. In particular, for the same $d$, $b$, $T_{in}$, and $T_{out}$, the KV-cache memory relative to a decoder-only model is simply

$$\frac{L_{dec}}{L}.$$

Our models keep the same total depth ($L_{enc}+L_{dec}=L$) and assign 2/3 of the layers to the encoder (context-decoder) and 1/3 to the decoder (generation-decoder), so $L_{dec}/L = 1/3$ and the cache is one third of the decoder-only cache: a **$\frac{2}{3}$** KV-cache memory reduction.
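As a sanity check on these formulas, the following minimal sketch evaluates both cache sizes in bytes under the assumptions stated above (MHA, the same $d$ and $b$ for both models). The function names and the example sizes ($d=1024$, $L=24$, $T_{in}=2048$, $T_{out}=256$) are hypothetical and chosen only for illustration; they are not configurations from the experiments.

```python
def kv_cache_bytes_decoder_only(d: int, L: int, T_in: int, T_out: int, b: int = 2) -> int:
    """Decoder-only cache: a key and a value (factor 2) of width d, b bytes per
    element, for every layer and every context + output token."""
    return 2 * d * b * L * (T_in + T_out)


def kv_cache_bytes_dual_stack(d: int, L_dec: int, T_in: int, T_out: int, b: int = 2) -> int:
    """Encoder-decoder / double-decoder cache: each of the L_dec decoder layers
    stores K/V for the T_out output tokens (self-attention) and K/V projections
    of the final encoder / context-decoder state for the T_in context tokens."""
    self_attn = 2 * d * b * L_dec * T_out
    cross_attn = 2 * d * b * L_dec * T_in
    return self_attn + cross_attn


# Hypothetical sizes for illustration: 24 layers split 2/3 context, 1/3 generation.
d, L, T_in, T_out = 1024, 24, 2048, 256
L_dec = L // 3

full = kv_cache_bytes_decoder_only(d, L, T_in, T_out)
dual = kv_cache_bytes_dual_stack(d, L_dec, T_in, T_out)
print(f"decoder-only: {full / 2**20:.0f} MiB, dual-stack: {dual / 2**20:.0f} MiB")
print(f"ratio: {dual / full:.4f} (L_dec / L = {L_dec / L:.4f})")
```

At these illustrative sizes the dual-stack cache is 72 MiB against 216 MiB for the decoder-only model, i.e. exactly the $L_{dec}/L = 1/3$ ratio derived above.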
### A.2 FLOP heuristic calculations
With the variables defined in the previous section, and with $T$ the number of tokens in a training sequence, we work through a single layer of self-attention (all heads together). In a decoder-only model, the four linear projections (query, key, value, output) contribute $8Td^2$ FLOPs, the two matrix multiplications in SDPA contribute $4T^2d$, and the two FF layers contribute $16Td^2$. Together, this is

$$24Td^2 + 4T^2d.$$

We multiply by $L$ for the number of layers and, since we are studying FLOPs at train time, by 3: one for the forward pass and two for the backward pass (gradients with respect to activations and with respect to weights). Hence

$$L(72Td^2 + 12T^2d).$$

The encoder-decoder formula is calculated in a similar way. Since our models use a 2/3, 1/3 layer split, we multiply by $\frac{2L}{3}$ for the encoder and $\frac{L}{3}$ for the decoder. The FLOP count for a single encoder layer is identical to that of a single decoder-only layer, so the encoder contributes

$$3\cdot\frac{2}{3}L(24T_{in}d^2 + 4T_{in}^2d) = L(48T_{in}d^2 + 8T_{in}^2d).$$

For the decoder, cross-attention over the encoder's output and self-attention over the shorter output sequence complicate the calculation somewhat. Relative to a decoder-only layer on $T_{out}$ tokens, each decoder layer adds the cross-attention query and output projections ($4T_{out}d^2$), the key and value projections of the encoder output ($4T_{in}d^2$), and the cross-attention SDPA ($4T_{in}T_{out}d$), so we find

$$L(28T_{out}d^2 + 4T_{out}^2d + 4T_{in}T_{out}d + 4T_{in}d^2).$$

Summing these two yields

$$L\big((52T_{in} + 28T_{out})d^2 + (4T_{out}^2 + 4T_{in}T_{out} + 8T_{in}^2)d\big).$$

Finally, for the double decoder architecture, the context-decoder is a standard causal decoder with $\frac{2L}{3}$ layers, so its FLOP count is once again

$$L(48Td^2 + 8T^2d).$$

The generation-decoder has six linear projections and two SDPA modules, at least in our naive implementation. However, an ideal implementation that skips masked-out entries would bring the attention cost down to the equivalent of a single SDPA: because both attention modules are causally masked, their required query-key and attention-value multiplications together amount to those of one unmasked attention head. Hence, the linear projections yield $12Td^2$, the SDPA yields $4T^2d$, and the two FF layers add $16Td^2$, so the generation-decoder uses

$$L(28Td^2 + 4T^2d),$$

and the double decoder uses altogether

$$L(76Td^2 + 12T^2d).$$
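The per-layer counts above can be checked mechanically. The following minimal sketch encodes each layer type's forward FLOPs as described in this section, multiplies by 3 for training, and asserts that the totals match the closed forms. Two assumptions are built in: the FF hidden width is $4d$ (inferred from the $16Td^2$ term), and the generation-decoder uses the idealized single-SDPA attention count. Function names and example sizes are hypothetical.

```python
from fractions import Fraction

TRAIN_MULT = 3  # forward pass + backward pass (~2x forward: activation and weight gradients)


def decoder_layer_fwd(T: int, d: int) -> int:
    """Forward FLOPs of one decoder-only (or encoder) layer, as counted in A.2."""
    proj = 8 * T * d**2      # Q, K, V, O projections: 4 matmuls of 2*T*d*d FLOPs
    sdpa = 4 * T**2 * d      # Q K^T and the attention-weighted sum over V
    ff = 16 * T * d**2       # two FF matmuls, assuming a 4d hidden width
    return proj + sdpa + ff


def cross_decoder_layer_fwd(T_in: int, T_out: int, d: int) -> int:
    """Forward FLOPs of one decoder layer in an encoder-decoder."""
    self_attn = 8 * T_out * d**2 + 4 * T_out**2 * d
    cross_q_o = 4 * T_out * d**2        # cross-attention Q and O projections
    cross_k_v = 4 * T_in * d**2         # K and V projections of the encoder output
    cross_sdpa = 4 * T_in * T_out * d   # T_out queries attending over T_in keys
    ff = 16 * T_out * d**2
    return self_attn + cross_q_o + cross_k_v + cross_sdpa + ff


def generation_decoder_layer_fwd(T: int, d: int) -> int:
    """Six projections, two masked attention modules counted as one unmasked SDPA."""
    return 12 * T * d**2 + 4 * T**2 * d + 16 * T * d**2


def decoder_only(L, T, d):
    return TRAIN_MULT * L * decoder_layer_fwd(T, d)


def encoder_decoder(L, T_in, T_out, d):
    L_enc, L_dec = Fraction(2, 3) * L, Fraction(1, 3) * L
    return TRAIN_MULT * (L_enc * decoder_layer_fwd(T_in, d)
                         + L_dec * cross_decoder_layer_fwd(T_in, T_out, d))


def double_decoder(L, T, d):
    L_ctx, L_gen = Fraction(2, 3) * L, Fraction(1, 3) * L
    return TRAIN_MULT * (L_ctx * decoder_layer_fwd(T, d)
                         + L_gen * generation_decoder_layer_fwd(T, d))


# Hypothetical sizes; check the per-layer counts against the closed forms.
L, d, T_in, T_out = 24, 1024, 1536, 512
T = T_in + T_out
assert decoder_only(L, T, d) == L * (72 * T * d**2 + 12 * T**2 * d)
assert encoder_decoder(L, T_in, T_out, d) == L * (
    (52 * T_in + 28 * T_out) * d**2 + (4 * T_out**2 + 4 * T_in * T_out + 8 * T_in**2) * d)
assert double_decoder(L, T, d) == L * (76 * T * d**2 + 12 * T**2 * d)
print(f"double decoder / decoder-only: {float(double_decoder(L, T, d) / decoder_only(L, T, d)):.3f}")
```

The ratio of the two training costs reduces to $(76d + 12T)/(72d + 12T)$, so at these illustrative sizes ($d=1024$, $T=2048$) the double decoder's idealized training cost is roughly 4% above the decoder-only model's.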
### A.3 μP hyperparameter sweeps
The following figures show our learning-rate and weight-decay sweeps under μP.

Figure 6: Results of μP learning-rate transfer. Left: the best learning rate for the smallest model (0.5M). Middle: that learning rate compared with others for larger models. Right: the best learning rate remains constant across scales.

Figure 7: Results of weight-decay sweeps after finding the best learning rate.
### A.4 Common hyperparameter table

| Category | Hyperparameter | Value |
|---|---|---|
| Data | Sequence length | 2048 tokens |
| Data | Tokenizer | 32k BPE tokenizer |
| Initialization | Weight initialization | Xavier uniform |
| Architecture | Input/output embeddings | Tied |
| Architecture | Attention head dimension | 64 |
| Architecture | Positional encoding | Rotary positional embeddings |
| Optimization | Optimizer | AdamW |
| Optimization | AdamW betas | $(0.9, 0.95)$ |
| Optimization | AdamW epsilon | $10^{-8}$ |
| Optimization | LR schedule | Linear warmup for 5% of steps, then linear decay |
| Optimization | Final LR fraction | 0.1 of peak LR |
| Optimization | Gradient clipping | Global norm clipped to 1.0 |
| μP | Base width | 64 |
| μP | Hidden-weight LR rule | $\eta_{\mathrm{hidden}} = \eta_{\mathrm{base}} \cdot 64/d$ |
| μP | Embedding/output/norm LR | $\eta_{\mathrm{base}}$ |
| Pretraining | Base LR source | Per-architecture tuned LR, unless overridden |
| Pretraining | Weight decay source | Per-architecture tuned WD, unless overridden |
| PrefixLM SFT | Base LR | $2\times 10^{-4}$ |
| PrefixLM SFT | Hidden-weight LR | $2\times 10^{-4}\cdot 64/d$ |
| PrefixLM SFT | SFT token budget | 10% of pretraining tokens by default |
| PrefixLM SFT | Effective batch size | 32 sequences |

Table 1: Hyperparameters shared across model families. Here $d$ denotes model width. Architecture size, layer allocation, pretraining objective, and collator differ by model family and are therefore omitted.
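Table 1 combines a μP learning-rate rule with a warmup-then-decay schedule. The following minimal sketch computes the learning rate of a parameter group at a given step under those settings. The exact shape of the warmup ramp (starting at $1/\text{warmup steps}$ of the peak), the step indexing, and the group names are assumptions for illustration; the table does not specify them, and this is not the authors' training code.

```python
BASE_WIDTH = 64        # μP base width (Table 1)
WARMUP_FRAC = 0.05     # linear warmup over the first 5% of steps (Table 1)
FINAL_LR_FRAC = 0.1    # decay ends at 0.1 of the peak LR (Table 1)


def schedule_multiplier(step: int, total_steps: int) -> float:
    """Linear warmup to 1.0, then linear decay to FINAL_LR_FRAC at the last step.

    Assumption: steps are 0-indexed and warmup ramps as (step + 1) / warmup_steps.
    """
    warmup_steps = max(1, int(WARMUP_FRAC * total_steps))
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - 1 - warmup_steps)
    return 1.0 - (1.0 - FINAL_LR_FRAC) * min(progress, 1.0)


def group_lr(eta_base: float, d: int, group: str, step: int, total_steps: int) -> float:
    """μP learning rate for a parameter group at a given step.

    'hidden' weights use eta_base * BASE_WIDTH / d; every other group
    (embeddings, the tied output layer, norm parameters) keeps eta_base.
    """
    peak = eta_base * BASE_WIDTH / d if group == "hidden" else eta_base
    return peak * schedule_multiplier(step, total_steps)


# Usage with the PrefixLM SFT base LR and a hypothetical width d = 512.
for group in ("hidden", "embedding"):
    print(group, group_lr(2e-4, 512, group, step=50, total_steps=1000))
```

For example, with the PrefixLM SFT base LR of $2\times10^{-4}$ and a hypothetical width $d=512$, hidden weights peak at $2.5\times10^{-5}$, while embedding, output, and norm parameters peak at $2\times10^{-4}$.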