学习的中继表示用于前瞻性离散扩散模型

arXiv cs.LG 论文

摘要

本文介绍了学习的中继表示(Relay),一种使掩码扩散模型能够在去噪步骤之间传播潜在信息的方法,克服了硬重置问题并改善了性能-延迟权衡。该方法在编码任务上优于标准的监督微调,同时将推理延迟降低高达32%。

arXiv:2605.22967v1 公告类型:新 摘要:当掩码扩散模型(MDM)通过迭代精炼生成序列时,掩码位置上的丰富内部计算被丢弃,迫使每个后续精炼步骤重新计算存储为模型表示的宝贵内部信息。为了避免去噪轮次之间的硬重置,我们提出了学习的中继表示(Relay),一种通过显式学习如何传播潜在信息以利于未来去噪步骤,使MDM在去噪时具有前瞻性的方法。Relay引入了一个可微的每令牌通道,在前向传播之间传递信息,并通过截断的时间反向传播(BPTT)进行训练。我们证明该框架可以扩展到最先进的扩散语言模型(DLM),并与块扩散和KV缓存等技术无缝兼容。我们首先在基于数独的挑战性规划任务上对Relay的设计选择进行了彻底验证。然后,我们将Relay扩展到最先进的DLM Fast-dLLM v2,在编码任务上优于标准监督微调,同时将推理延迟降低高达32%。我们的实验结果表明,最先进的DLM可以显式地训练以跨解码步骤向前中继潜在信息,推动性能-延迟的帕累托前沿。我们提供了所有实验的代码。
查看原文
查看缓存全文

缓存时间: 2026/05/25 08:57

# 学习的中继表征:面向前向思考的离散扩散模型
来源:https://arxiv.org/html/2605.22967
Benjamin Rozonoyer¹,Jacopo Minniti²¹¹,Dhruvesh Patel¹¹¹,Neil Band³,Avishek Joey Bose⁴,⁵,Tim G. J. Rudner²,⁶,Andrew McCallum¹
¹马萨诸塞大学阿默斯特分校
²多伦多大学
³斯坦福大学
⁴帝国理工学院
⁵Mila
⁶Vijil

###### 摘要

当掩码扩散模型(MDMs)通过迭代精炼生成序列时,掩码位置上的丰富内部计算被丢弃——这迫使后续的每一个精炼步骤都要重新计算存储在模型表征中的宝贵内部信息。为了避免去噪轮次之间的硬重置,我们提出学习的中继表征(Relay),这是一种使MDMs在去噪时能够“前向思考”的方法——*显式地学习如何传播潜在信息,以利于未来的去噪步骤*。Relay引入了一个可微分的每词元通道,可以在前向传递之间传递信息,并通过截断的时间反向传播(BPTT)进行训练。我们展示了该框架可以扩展到最先进的扩散语言模型(DLMs),并且与块扩散和KV缓存等技术无缝兼容。我们首先在具有挑战性的基于数独的规划任务上,对Relay的设计选择进行了详尽的验证。然后,我们将Relay扩展到最先进的DLM——Fast-dLLM v2,在编码任务上超越了标准的监督微调,同时将推理延迟降低了多达32%。我们的实证结果表明,最先进的DLM可以被显式地训练,在解码步骤间向前中继潜在信息,从而推进了性能-延迟的帕累托前沿。我们为所有实验提供了代码。

## 1 引言

掩码扩散模型(MDMs)通过迭代去噪生成离散序列(Austin et al., 2021 (https://arxiv.org/html/2605.22967#bib.bib1); Campbell et al., 2022 (https://arxiv.org/html/2605.22967#bib.bib2); Sahoo et al., 2024 (https://arxiv.org/html/2605.22967#bib.bib3); Shi et al., 2024 (https://arxiv.org/html/2605.22967#bib.bib4)):从一个完全掩码的画布开始,每次前向传递会揭开剩余位置中的一部分。Transformer在每个位置(包括仍然被掩码的位置)计算隐藏状态,但在每一步结束时丢弃它们,下一次传递仅从部分解掩码的序列开始。我们称此为*硬重置*问题:跨步骤持久化的唯一信息是刚刚确定的离散词元,这使得MDMs无法累积中间连续计算。

这之所以重要,是因为循环计算——在一个固定参数的模型上展开多个步骤——正是近期工作证明与困难推理任务上性能提升相关的结构特性,因为它有效扩展了模型可以逼近的函数类别(Gatmiry et al., 2024 (https://arxiv.org/html/2605.22967#bib.bib5); Saunshi et al., 2025 (https://arxiv.org/html/2605.22967#bib.bib6); Li et al., 2024 (https://arxiv.org/html/2605.22967#bib.bib7))。MDMs每次生成已经执行多次前向传递;硬重置阻止了这些计算被重用。

这引出一个自然的问题:MDMs的顺序解掩码结构如何支持在步骤间携带更丰富信息的循环计算?

我们的答案是学习的中继表征(Relay),一种使离散扩散模型具有*前向思考*能力的方法:在每个去噪步骤,除了任何新解掩码的词元外,模型还将其最后一层隐藏状态作为学习到的*中继*向前传递,使下一次前向传递能够直接访问上一步的连续计算。然而,简单地向前传递这些状态本身并不能确保它们对后续步骤编码任何有用信息。因此,Relay通过截断的时间反向传播(BPTT; Werbos, 1990 (https://arxiv.org/html/2605.22967#bib.bib8))对中继进行端到端训练,使其对后续几个去噪步骤具有最大信息量,并在解掩码轨迹上实现一种潜在思维链形式。

贡献。我们提出Relay,它为MDMs配备了学习的中继表征——跨解码步骤向前传递的连续潜在状态,并通过截断的BPTT进行端到端训练。Relay是架构无关的,并且不改变MDMs的推理时解码过程(解掩码调度、采样);推理时的唯一增加是将中继与已确定的词元一起向前传递。它还与主流的DLM加速技术兼容,包括块扩散(Arriola et al., 2025 (https://arxiv.org/html/2605.22967#bib.bib9))和KV缓存(Wu et al., 2025a (https://arxiv.org/html/2605.22967#bib.bib10), b (https://arxiv.org/html/2605.22967#bib.bib11))。

总结来说,我们的主要贡献如下:

1.  1. 我们提出Relay,一种在MDMs中融入循环计算的通用方法,通过截断的BPTT训练模型,使其在解码步骤间向前传递学习到的潜在中继。Relay可以从头训练MDM,或通过轻量级适配微调预训练的MDM。
2.  2. 我们通过对Fast-dLLM v2 1.5B(Wu et al., 2025b (https://arxiv.org/html/2605.22967#bib.bib11))进行全参数适配,在LLM规模上验证了Relay,在编码任务上优于标准监督微调,同时将推理延迟降低了多达32%。
3.  3. 我们进行了广泛的消融实验,绘制了Relay的设计空间图,并验证了我们的选择。

## 2 背景:掩码扩散模型

我们通过训练掩码扩散模型传递学习到的中继状态来解决硬重置问题。在介绍我们的方法Relay之前,我们回顾一下Relay所基于的掩码扩散模型(MDMs)(Shi et al., 2024 (https://arxiv.org/html/2605.22967#bib.bib4); Sahoo et al., 2024 (https://arxiv.org/html/2605.22967#bib.bib3))的训练和推理过程。

符号表示。我们将词汇表表示为V\{\mathcal{V}\},包括[M]词元。长度为LL的序列在词汇表上的空间为VL\{\mathcal{V}\}^L。上标表示序列中的位置,例如,xix^i是序列x∈VL\{\bm{x}\}\in\{\mathcal{V}\}^L中的第ii个词元。M(x)⊆[L]\mathcal{M}(\{\bm{x}\})\subseteq[L]表示序列x\{\bm{x}\}中被掩码位置的集合。

训练。加噪过程通过采样一个时间t∈[0,1]t\in[0,1]并以概率αt\alpha_t独立地掩码干净序列x0∈(V∖{[M]})L\{\bm{x}\}_0\in(\{\mathcal{V}\}\setminus\{\texttt{[M]}\})^L中的每个位置,得到加噪(部分掩码)序列xt\{\bm{x}\}_t。逐坐标的后验分布P(X0i=x0i|Xt=xt)\mathbb{P}(X_0^i=x_0^i\,\vert\,\bm{X}_t=\{\bm{x}\}_t)记为p(x0i∣xt)p(x_0^i\mid\{\bm{x}\}_t)。如Zheng et al. (2024 (https://arxiv.org/html/2605.22967#bib.bib12))所述,该后验仅通过其掩码模式和已揭示的词元依赖于xt\{\bm{x}\}_t,而不依赖于时间tt本身。逐坐标后验由神经网络参数化,记为pθi(⋅∣xt)∈Δp_\theta^i(\cdot\mid\{\bm{x}\}_t)\in\Delta,其中i∈M(xt)i\in\mathcal{M}(\{\bm{x}\}_t),并通过最小化每个掩码位置的加权交叉熵损失之和进行训练¹¹我们假设了一个线性噪声调度。:

L(θ)=Ex0,t,xt[1t∑i:xti=[M]−log pθi(x0i∣xt)].\mathcal{L}(\theta)=\mathop{\mathbb{E}}_{\{\bm{x}\}_0,t,\{\bm{x}\}_t}\left[\frac{1}{t}\sum\nolimits_{i:\{\bm{x}\}_t^i=\texttt{[M]}}-\log p^i_\theta(x_0^i\mid\{\bm{x}\}_t)\right].(1)逐坐标参数化后验通过嵌入Embθ:V→Rd\text{Emb}_\theta:\{\mathcal{V}\}\to\mathbb{R}^d,解嵌入UnEmbθ:Rd→R|V|\text{UnEmb}_\theta:\mathbb{R}^d\to\mathbb{R}^{|\{\mathcal{V}\}|},以及Transformer主干fθ:VL→RL×df_\theta:\{\mathcal{V}\}^L\to\mathbb{R}^{L\times d}实现,它们产生后验分布:

pθi(w∣xt)\displaystyle p_\theta^i(w\mid\{\bm{x}\}_t)=eli(w)∑w′∈Veli(w′),其中li(w)=UnEmbθ(fθ(Embθ(xt)))wi.\displaystyle=\frac{e^{\ell^i(w)}}{\sum_{w'\in\{\mathcal{V}\}}e^{\ell^i(w')}},\quad\text{其中}\quad\ell^i(w)=\text{UnEmb}_\theta(f_\theta(\text{Emb}_\theta(\{\bm{x}\}_t)))^i_w.
推理。生成过程沿着递减的时间网格1=t0>t1>⋯>tK=01=t_0>t_1>\cdots>t_K=0进行,从全掩码序列xt0=([M],...,[M])\{\bm{x}\}_{t_0}=(\texttt{[M]},\ldots,\texttt{[M]})迭代解掩码,直到完全解掩码的序列xtK∈(V∖{[M]})L\{\bm{x}\}_{t_K}\in(\{\mathcal{V}\}\setminus\{\texttt{[M]}\})^L。在每一步kk,给定当前部分掩码序列xtk\{\bm{x}\}_{t_k},模型为每个被掩码位置i∈M(xtk)i\in\mathcal{M}(\{\bm{x}\}_{t_k})和词元w∈Vw\in\{\mathcal{V}\}计算每个位置后验分布的logitslk\bm{\ell}_k。一个解掩码策略u(⋅∣lk,xtk)u(\cdot\mid\bm{\ell}_k,\{\bm{x}\}_{t_k})(可以是随机的)然后选择一组位置Uk⊆M(xtk)\mathcal{U}_k\subseteq\mathcal{M}(\{\bm{x}\}_{t_k})进行揭示,产生下一个部分掩码序列xtk+1\{\bm{x}\}_{t_{k+1}}。u(⋅∣lk,xtk)u(\cdot\mid\bm{\ell}_k,\{\bm{x}\}_{t_k})的常见选择包括在每一步解掩码固定比例的剩余掩码(Nie et al., 2025 (https://arxiv.org/html/2605.22967#bib.bib13))以及基于置信度的规则(Ben-Hamu et al., 2025 (https://arxiv.org/html/2605.22967#bib.bib14); Kim et al., 2025 (https://arxiv.org/html/2605.22967#bib.bib15); Patel et al., 2025 (https://arxiv.org/html/2605.22967#bib.bib16))。

硬重置问题。在每次推理步骤后,MDMs丢弃用于选择新揭示词元的整个计算状态。下一步仅从xtk+1\{\bm{x}\}_{t_{k+1}}重新开始。因此,标准MDM推理将每个部分掩码序列视为一个新的预测问题——一个*硬重置*——而不是一个正在进行计算的延续。由于模型每次前向传递只能执行固定数量的FLOPs,硬重置阻止了模型有效地在步骤间分摊推理。在下一节中,我们提出解决此问题的方法:我们学习一个连续的潜在状态,该状态在MDM推理的步骤间传递,并能够绕过硬重置。

## 3 学习的中继表征

为了解决*硬重置*问题,我们引入了一个连续的、可微分的状态,该状态在MDM推理步骤间携带,并且可以绕过硬重置。

### 3.1 增强状态轨迹

MDMs的训练过程是:采样一个数据点x0∼pdata\{\bm{x}\}_0\sim p_{\mathrm{data}},一个时间t∼U(0,1)t\sim\mathcal{U}(0,1),以及在该时间tt和数据点x0\{\bm{x}\}_0下根据噪声调度得到部分掩码序列xt\{\bm{x}\}_t。在推理过程中,我们有一个离散化的时间网格1=t0>⋯>tn=01=t_0>\cdots>t_n=0,以及通过使用某个解掩码策略uu得到的对应推理轨迹xt0,...,xtn\{\bm{x}\}_{t_0},\ldots,\{\bm{x}\}_{t_n},其中xt0={[M]}L\{\bm{x}\}_{t_0}=\{\texttt{[M]}\}^L。我们希望将一个连续状态向前传递跨越解码步骤,该状态可以携带上一步中尚未实现为解码词元的中间计算。我们可以将此行为分解为两个原语:模型必须在推理步骤kk产生一个中继状态hk\{\bm{h}\}_k,并学会在步骤k+1k+1消费该中继状态。图̃1 (https://arxiv.org/html/2605.22967#S3.F1)展示了模型产生的增强状态轨迹的示意图,其中sk=(xtk,hk)\{\bm{s}\}_k=(\{\bm{x}\}_{t_k},\{\bm{h}\}_k)是步骤kk的增强状态。

### 3.2 训练

架构。我们使用主干fθf_\theta、中继模块RθR_\theta、词元嵌入Embθ\text{Emb}_\theta和解嵌入头UnEmbθ\text{UnEmb}_\theta(见图̃1 (https://arxiv.org/html/2605.22967#S3.F1))来参数化增强动力学。在步骤kk,模型将当前对(xtk,hk)(\{\bm{x}\}_{t_k},\{\bm{h}\}_k)映射到下一个中继状态和每个位置的logits,通过

hk+1\displaystyle\{\bm{h}\}_{k+1}=fθ(Embθ(xtk)+Rθ(hk)),\displaystyle=f_\theta\!\left(\text{Emb}_\theta(\{\bm{x}\}_{t_k})+R_\theta(\{\bm{h}\}_k)\right),lk\displaystyle\bm{\ell}_k=UnEmbθ(hk+1),\displaystyle=\text{UnEmb}_\theta(\{\bm{h}\}_{k+1}),(2)初始化为h0=0\{\bm{h}\}_0=\{\bm{0}\}。每个位置后验pθi(⋅∣xtk,hk)p_\theta^i(\cdot\mid\{\bm{x}\}_{t_k},\{\bm{h}\}_k)通过softmax从lk\bm{\ell}_k读出,与标准MDMs完全相同。

由于我们只关心最终状态xtn\{\bm{x}\}_{t_n},我们继续使用与标准MDMs相同的交叉熵损失提供监督,并使用截断BPTT训练模型产生有助于改进未来KK步预测的有用中继状态hk\{\bm{h}\}_k。具体来说,我们不像标准MDMs那样采样xt\{\bm{x}\}_t,而是从全掩码序列xt0={[M]}L\{\bm{x}\}_{t_0}=\{\texttt{[M]}\}^L开始,在方程̃2 (https://arxiv.org/html/2605.22967#S3.E2)下结合解掩码策略uu(见下文)展开,产生增强轨迹(xt0,h0),...,(xtn,hn)(\{\bm{x}\}_{t_0},\{\bm{h}\}_0),\ldots,(\{\bm{x}\}_{t_n},\{\bm{h}\}_n)。总训练损失是轨迹上每步交叉熵的期望和:

L(θ)\displaystyle\mathcal{L}(\theta)=Ex0,ξ0:n−1[∑k=0n−1∑i∈M(xtk)−log pθi(x0i∣xtk,hk)],\displaystyle=\mathop{\mathbb{E}}_{\{\bm{x}\}_0,\,\xi_{0:n-1}}\!\left[\,\sum_{k=0}^{n-1}\,\sum_{i\in\mathcal{M}(\{\bm{x}\}_{t_k})}-\log p^i_\theta\!\left(x_0^i\mid\{\bm{x}\}_{t_k},\{\bm{h}\}_k\right)\right],(3)其中ξ0:n−1\xi_{0:n-1}表示解掩码策略沿展开所使用的外生随机性。与外部观察到的条件变量不同,hk\{\bm{h}\}_k是展开的内部产物,是计算轨迹的一部分,而不是生成对象的一部分。在推理时,每一步将实现的对(xtk,hk)(\{\bm{x}\}_{t_k},\{\bm{h}\}_k)向前传递,但只有xtk\{\bm{x}\}_{t_k}最终解码为文本,而hk\{\bm{h}\}_k则作为未来预测的可微分记忆通道。完整过程总结在算法̃1 (https://arxiv.org/html/2605.22967#algorithm1)中;我们在下面推导梯度估计器。

**输入**:模型

fθf_\theta,中继模块

RθR_\theta,展开长度

KK,解掩码策略

uu,训练步数

NN,学习率

η\eta
1

2**对于** *t∈{1,...,N}t\in\{1,\ldots,N\}* **执行**

3**如果** *t=1t=1* 或 *M(z)=∅\mathcal{M}(\{\bm{z}\})=\emptyset* **则**

4

x0∼pdata

相似文章

基于时空并行解码与置信度外推的高效扩散LLMs

arXiv cs.CL

本文介绍了时空并行解码(TSPD)和置信度外推(CE),通过动态判断令牌何时收敛并预测logit趋势,来加速基于扩散的大语言模型的推理,减少不必要的去噪步骤,同时保持输出质量。

可学习性引导的扩散语言模型微调

arXiv cs.CL

我们提出LIFT,一种可学习性引导的扩散语言模型微调算法,该算法根据 token 难度和时间步对齐训练,在推理基准测试上取得了显著提升。

重新思考扩散Transformer中的跨层信息路由

Hugging Face Daily Papers

本文提出扩散自适应路由(DAR),这是一种可学习的、时间步自适应的残差替换方法,旨在改善扩散Transformer中的跨层信息流动,从而显著加速训练并提升质量。