用于优化离散扩散语言模型的漂移目标

arXiv cs.CL 论文

摘要

本文提出TokenDrift,一种漂移目标方法,通过将分类预测提升至连续语义空间进行反对称漂移,从而优化离散扩散语言模型。在固定去噪步数下,该方法显著提升了生成质量。

arXiv:2605.19470v1 公告类型:新 摘要:离散扩散语言模型(DDLMs)通过迭代去噪分类令牌序列来生成文本,而近期针对连续生成器的漂移方法表明,此类采样时校正可部分通过反对称不动点目标吸收到训练中。我们研究了如何将此原理迁移至DDLMs,主要挑战在于与离散文本的接口:硬令牌样本不可微,且分类预测无法直接提供可漂移的连续样本。我们制定了TokenDrift,一种漂移目标,它将分类预测提升为软令牌特征,在冻结语义空间中应用反对称漂移,并将所得停止梯度特征目标反向传播至DDLMs的logits。在基于掩码和均匀状态扩散骨干的受控持续训练实验中,TokenDrift改善了固定NFE下的生成质量,相比匹配的持续基线,在MDLM和DUO上4次NFE的Gen.-PPL分别降低89%和86%。这些结果表明漂移可为DDLMs提供实用的优化目标。
查看原文
查看缓存全文

缓存时间: 2026/05/20 08:25

# 用于精炼离散扩散语言模型的漂移目标
来源:https://arxiv.org/html/2605.19470
大场大辅¹ 古田弘树 冈崎直明¹,²,³ ¹东京科学大学 ²产业技术综合研究所 ³国立信息学研究所 LLMC \{daisuke\.oba@nlp\.,okazaki@\}comp\.isct\.ac\.jp

###### 摘要

离散扩散语言模型(DDLM)通过迭代去噪分类令牌序列来生成文本,而近期针对连续生成器的漂移方法表明,部分采样时的校正可以通过反反对称不动点目标吸收到训练中。我们研究如何将此原则迁移到DDLM,主要挑战在于与离散文本的接口:硬令牌样本不可微,且分类预测无法直接提供连续样本进行漂移。我们提出了TokenDrift,一种漂移目标,它将分类预测提升为软令牌特征,在冻结的语义空间中应用反对称漂移,并将所得的停止梯度特征目标反向传播到DDLM logits。在受控的持续训练实验中,使用掩码扩散和均匀状态扩散主干网络,TokenDrift在固定NFE下提升了生成质量,相比匹配的持续训练基线,在4个NFE下,MDLM的Gen.-PPL降低了89%,DUO降低了86%。这些结果表明漂移可以为DDLM提供一种实用的精炼目标。

项目页面:https://daioba.github.io/tokendrift/

## 1 引言

离散扩散语言模型(DDLM)通过迭代去噪被破坏的令牌序列,为从左到右生成提供了一种非自回归的替代方案[1, 13, 22, 25, 23, 17, 29, 31, 5, 24]。因此,它们的实际行为不仅受学习到的去噪器影响,还受在固定去噪步数下达到的质量影响。虽然许多工作通过设计新的调度、采样器或蒸馏生成器来改进采样,但我们提出一个互补的问题:*训练目标本身能否精炼现有的DDLM,使得在相同推理预算下,同一采样器能产生更好的样本?*

近期基于漂移的连续生成模型方法为这个问题提供了一个有前景的训练端视角[4]。它们用固定点训练目标(图1,左)替代了通常在采样过程中进行的部分迭代校正:生成的样本沿着吸引-排斥场移动,朝向附近的数据样本,远离附近的模型样本。这对于DDLM精炼很有吸引力,因为它直接针对生成的样本本身,而不仅仅是重建被破坏的令牌,从而为在固定推理预算下提高样本质量提供了一个直接目标。

然而,将漂移迁移到DDLM并非直接替换。在连续漂移中,生成的样本或其特征表示可以沿着漂移方向微调,并用作停止梯度训练目标。对于文本,生成器输出的是令牌上的分类分布。如果这些分布在特征提取前被坍缩为硬令牌,那么得到的特征空间损失就不再能为模型logits提供有用的梯度。因此,将漂移应用于DDLM需要从分类预测到定义漂移的连续特征空间之间的可微桥梁。

见图注图1:我们用于离散扩散语言模型的漂移公式概述。原始漂移通过将生成特征\(h\)沿着漂移场\(V\)移动来构造停止梯度目标\(h^\star\)。对于离散文本,硬令牌采样会阻断梯度,因此我们将令牌概率提升为软嵌入,在特征空间中计算漂移目标,并将损失反向传播到logits。我们提出了TokenDrift,一种通过软令牌特征精炼DDLM的漂移目标(图1,右)。TokenDrift不采样硬令牌,而是将从模型分类预测中计算的期望令牌嵌入输入到冻结的语义编码器。这创建了一条从特征空间漂移损失回到DDLM logits的可微路径。在该特征空间中,我们估计与连续漂移相同的吸引-排斥漂移,并训练生成器朝着停止梯度漂移特征目标前进。

我们研究的核心部分是公式化。有几种合理的方式可以将特征空间漂移与分类生成器连接起来:可以在特征空间中直接匹配漂移目标,将漂移转换为logit空间中的镜像教师,或者将任一目标与原始去噪损失相结合。我们在第4.3节中在匹配预算下比较了这些替代方案,发现直接的特征空间漂移能在基于似然的质量和熵之间提供最稳定的权衡。我们进一步表明软令牌提升是必要的:用硬直通令牌代理[3]替换它会严重降低生成质量,表明在语义特征空间中保留预测不确定性对于有效漂移很重要。

在OpenWebText(OWT)[7]上的受控持续训练实验中,TokenDrift在固定NFE下显著提升了生成质量,无论是相对于预训练的掩码扩散语言模型MDLM[22],还是普通的持续训练。相同的目标也改进了均匀状态离散扩散主干网络DUO[23],表明该效果并非掩码扩散所特有。这些结果共同将TokenDrift定位为现有DDLM的精炼目标:我们不设计新的采样器或蒸馏低NFE学生,而是保持模型类和采样器固定,通过训练目标改进生成器。

贡献。总之,我们将漂移公式化为DDLM的可训练精炼目标,确定了使其对离散文本有效所需的设计选择,并展示了对掩码扩散和均匀状态扩散语言模型的受控改进。

## 2 预备知识

我们简要回顾为DDLM制定漂移目标所需的要素:分类去噪模型、基于漂移的固定点学习,以及将连续漂移直接应用于文本的困难。

### 2.1 离散扩散语言模型

令\(\mathcal{V}\)为词汇表,且\(x=(x_1,\dots,x_L)\in\mathcal{V}^L\)为令牌序列。由\(\theta\)参数化的DDLM定义了令牌序列上的破坏过程,并训练一个去噪器\(f_\theta\)从破坏的输入预测干净令牌。在每个去噪步骤,模型输出logits \(\ell_\theta=f_\theta(x_t)\in\mathbb{R}^{L\times|\mathcal{V}|}\),这引出了位置上的分类分布\(p_{\theta,t}=\mathrm{softmax}(\ell_{\theta,t})\in\Delta^{|\mathcal{V}|-1}\)。

不同的DDLM以不同方式实例化破坏过程:掩码扩散将令牌破坏为掩码状态,而均匀状态扩散将令牌破坏为词汇表上的均匀分布。在两种情况下,去噪器都预测分类令牌分布,生成过程通过将去噪器应用于选定的步数来进行。

### 2.2 基于漂移的固定点学习

漂移为连续生成器定义了一个固定点训练规则[4]。对于生成的样本或特征\(y\),令\(V(y;P_\mathrm{data},P_\theta)\)为漂移场,通过朝向附近数据样本的吸引和远离附近模型样本的排斥来估计。漂移形成一个停止梯度目标

\[y^\star=\mathrm{sg}\!\left(y+\alpha V(y;P_\mathrm{data},P_\theta)\right),\]

并训练生成器匹配该目标。当原始样本空间距离无意义时,相同的构造可以在冻结的特征空间中应用。

一个关键的结构特性是反对称性:交换吸引和排斥分布会反转漂移方向。因此,如果模型和数据分布一致,则吸引和排斥抵消,\(V(\cdot;P,P)=0\),所以漂移信号在平衡时消失。我们的目标是在DDLM中保留这种固定点结构。

### 2.3 离散文本接口

连续漂移假设生成的样本或特征可以加性移动并通过微分。DDLM输出的是分类令牌分布:硬令牌化会阻断到logits的梯度,而直接的概率更新必须满足单纯形约束。因此,将漂移应用于DDLM需要从分类预测到定义漂移的连续特征空间之间的可微桥梁。

## 3 TokenDrift

我们将漂移公式化为DDLM的精炼目标。给定分类令牌预测,TokenDrift将其提升为软令牌特征,在冻结的语义空间中计算反对称吸引-排斥漂移,并训练生成器朝向停止梯度漂移特征目标。这使得特征空间漂移对令牌logits可微。

符号。我们使用\(i\)表示样本,\(t\)表示令牌位置,\(v\)表示词汇索引。对于样本\(i\),生成器输出logits \(\ell_i\!\in\!\mathbb{R}^{L\times|\mathcal{V}|}\)和分布\(p_i\!=\!\mathrm{softmax}(\ell_i)\!\in\!\mathbb{R}^{L\times|\mathcal{V}|}\),其中\(p_{i,t}\in\Delta^{|\mathcal{V}|-1}\)。

### 3.1 软令牌特征提升

令\(E\in\mathbb{R}^{|\mathcal{V}|\times d}\)为由冻结语义编码器\(\phi\)使用的嵌入矩阵,\(x_0^{(i)}\)为干净序列,\(\bar{x}^{(i)}\)为其破坏输入。给定\(\bar{x}^{(i)}\),生成器输出分布\(p_i\)。我们通过设置\(\tilde{e}_{i,t}=p_{i,t}E\)用于预测位置\(t\in\mathcal{M}_i\),否则设置\(\tilde{e}_{i,t}=E[\bar{x}_t^{(i)}]\)来构成编码器输入,其中\(\mathcal{M}_i\)由底层扩散主干网络定义。生成的对应特征和相应的真实特征为

\[h_i=\phi(\tilde{e}_i)\in\mathbb{R}^m,\qquad u_i=\phi(E[x_0^{(i)}])\in\mathbb{R}^m.\]

因此,\(h_i\)是模型完成序列的特征,而\(u_i\)是对应干净序列的特征。由于\(\tilde{e}_i\)通过\(p_i\)依赖于\(\ell_i\)在预测位置上,特征空间漂移损失可以反向传播到logits。对于没有观测令牌位置的主干网络,如均匀状态扩散,我们设置\(\mathcal{M}_i=\{1,\dots,L\}\)并在所有位置使用软嵌入。

### 3.2 反对称漂移估计

对于每个生成的特征\(h_i\),我们从真实数据特征构建正参考集\(\mathcal{P}_i\),从生成特征构建负参考集\(\mathcal{N}_i\)。漂移遵循漂移模型[4]的吸引-排斥结构:正样本将样本拉向数据分布,而负样本将其推离当前模型分布。

对于\(u_j\in\mathcal{P}_i\)和\(v_k\in\mathcal{N}_i\),我们计算温度缩放的亲和度\(s^+_{ij}=-\|h_i-u_j\|_2^2/\tau\)和\(s^-_{ik}=-\|h_i-v_k\|_2^2/\tau\)。遵循原始漂移构造,我们联合归一化正负亲和度以获得权重\(W^+_{ij}\)和\(W^-_{ik}\)。相应的正负重心定义了温度\(\tau\)下的漂移场:

\[b_i^+=\sum\nolimits_j W^+_{ij} u_j,\quad b_i^-=\sum\nolimits_k W^-_{ik} v_k,\quad V_i^{(\tau)}=b_i^+-b_i^-.\]

因此,漂移从附近的模型特征指向附近的数据特征。

多温度漂移。我们为每个\(\tau\in\mathcal{T}\)计算\(V_i^{(\tau)}\),通过标量批次级RMS尺度归一化每个温度,并求平均:

\[s^{(\tau)}=\sqrt{\operatorname{mean}_i\|V_i^{(\tau)}\|_2^2+\epsilon},\quad V_i=\frac{1}{|\mathcal{T}|}\sum\nolimits_{\tau\in\mathcal{T}}\frac{V_i^{(\tau)}}{s^{(\tau)}}.\]

这防止了任何单一温度尺度主导最终漂移。

### 3.3 特征空间固定点目标

我们的主要目标是直接的特征空间固定点损失。给定当前生成特征\(h_i\)和漂移\(V_i\),我们构成停止梯度目标\(h_i^\star=\mathrm{sg}\!\left(h_i+\alpha V_i\right)\),其中\(\alpha>0\)是漂移尺度。漂移目标为

\[\mathcal{L}_\mathrm{drift}=\frac{1}{2B}\sum_{i=1}^B\left\|h_i-h_i^\star\right\|_2^2.\]

因为\(h_i^\star\)被冻结,特征空间梯度正比于\(-V_i\),所以最小化此损失将\(h_i\)推向漂移方向。通过软令牌提升,该特征空间信号反向传播到令牌logits。

与基础目标的关系。漂移目标可以单独使用,也可以与原始离散扩散训练目标结合使用。除非另有说明,我们的主要方法将漂移目标用作主要训练信号,而组合变体作为消融进行评估(第4.3节)。

### 3.4 通过镜像教师的替代公式

作为直接特征空间匹配的替代,我们还考虑将特征空间漂移转换为令牌级教师分布。令\(h_i=h(\ell_i)\)为由logits诱导的特征。我们定义

\[g_i=\nabla_{\ell_i}\!\left(h(\ell_i)^\top\mathrm{sg}(V_i)\right),\quad \ell_i^\star=\mathrm{sg}(\ell_i+\eta g_i),\quad p_i^\star=\mathrm{softmax}(\ell_i^\star).\]

Softmax将更新后的logits映射回有效的分类分布,产生一个单纯形感知的镜像教师。我们在预测位置\(\mathcal{M}_i\)上评估两种匹配损失:分布KL散度和logit空间MSE:

\[\mathcal{L}_{\mathrm{mirror-KL}}=\frac{1}{B}\sum_i\sum_{t\in\mathcal{M}_i}\mathrm{KL}(p_{i,t}^\star\|p_{i,t}),\quad \mathcal{L}_{\mathrm{mirror-MSE}}=\frac{1}{B}\sum_i\sum_{t\in\mathcal{M}_i}\|\ell_{i,t}^\star-\ell_{i,t}\|_2^2.\]

相似文章

GDSD:强化学习作为扩散语言模型的引导式降噪器自蒸馏

Hugging Face Daily Papers

GDSD提出了一种强化学习方法,直接从优势引导的自教师中蒸馏扩散语言模型的降噪器,避免了基于ELBO的似然代理带来的偏差。在规划、数学和编码基准上,比先前最先进的方法准确率提升高达+19.6%。

可学习性引导的扩散语言模型微调

arXiv cs.CL

我们提出LIFT,一种可学习性引导的扩散语言模型微调算法,该算法根据 token 难度和时间步对齐训练,在推理基准测试上取得了显著提升。

Discrete Stochastic Localization用于非自回归生成

arXiv cs.LG

提出离散随机定位(Discrete Stochastic Localization, DSL),一种用于非自回归文本生成的连续状态扩散框架,采用单位球面令牌嵌入和时步不变的降噪器,在OpenWebText上实现了比掩码离散扩散模型更好的分布忠实性。