Block-Based Double Decoders
摘要
提出了一种基于块的雙解碼器(block-based double decoders),这是一种使用双重因果块注意力掩码的新型Transformer架构,结合了解码器仅训练效率与编码器-解码器推理效率,实现了强大的扩展性能并减少了KV缓存内存。
arXiv:2605.18807v1 公告类型:新
摘要:编码器-解码器模型在推理时间上比仅解码器模型节省了大量资源,但其预训练目标存在稀疏监督和动态序列长度的问题,导致在大规模实践中难以应用。我们提出了基于块的雙解碼器(block-based double decoders),这是一种利用双重因果块注意力掩码的新型Transformer架构,通过全损失监督和静态序列打包进行训练,结合了解码器仅训练效率与编码器-解码器推理效率。在扩展律实验中,基于块的雙解碼器在性能上显著优于编码器-解码器,并在不同规模上紧密跟踪仅解码器模型。在推理时,它们将KV缓存内存和每令牌计算量削减至少2/3,且无需牺牲预填充缓存或其他仅解码器模型可用的现有推理优化。
查看缓存全文
缓存时间: 2026/05/20 08:37
# 基于块的双解码器
来源:https://arxiv.org/html/2605.18807
Vanessa Alexander vanessa\_alexander@brown\.edu &Benjamin Bradley benjamin\_bradley@brown\.edu &Chaitanya Harsha chaitanya\_harsha@brown\.edu &Asher Labovich asher\_labovich@brown\.edu 布朗大学计算机科学系 普罗维登斯,RI 02912
###### 摘要
编码器-解码器模型相比仅解码器模型可节省大量推理时间,但其预训练目标存在稀疏监督和动态序列长度的问题,导致它们难以在大规模实践中应用。我们提出**基于块的双解码器**,一种新颖的Transformer架构,利用**双因果块注意力掩码**实现完整的损失监督和静态序列打包,结合了仅解码器的训练效率与编码器-解码器的推理效率。在规模定律实验中,基于块的双解码器显著优于编码器-解码器,并在各个规模上紧密追赶仅解码器模型。在推理时,它们将KV缓存内存和每令牌计算量减少至少 \(\frac{2}{3}\),同时不牺牲预填充缓存或仅解码器模型可用的其他现有推理优化。
## 1 引言
近年来,Transformer [16 (https://arxiv.org/html/2605.18807#bib.bib1)] 的兴起推动了自然语言建模的进步。尽管原始Transformer采用基于编码器-解码器的架构,其中编码器使用完全注意力进行理解,解码器使用因果注意力进行预测,但仅解码器架构因其在文本生成场景中的可扩展性而脱颖而出 [11 (https://arxiv.org/html/2605.18807#bib.bib3)]。然而,近期对编码器-解码器架构 [3 (https://arxiv.org/html/2605.18807#bib.bib4)] 的关注再次升温,因其在效率和计算及存储受限环境下的有效性,它们能实现KV缓存的显著减少。
即使有这些工作,已有研究 [12 (https://arxiv.org/html/2605.18807#bib.bib5)] 表明,编码器和解码器之间的架构边界会导致效率低下,并且一个 \(2P\) 参数的编码器-解码器模型与一个 \(P\) 参数的仅解码器架构产生相同的计算成本。另一种替代的单Transformer方法PrefixLM,对输入前缀使用完全注意力,之后使用因果注意力,试图结合编码器-解码器的双向上下文和仅解码器模型的参数共享。该方法受益于较少的动态批处理,但会留下许多令牌未被训练。
但PrefixLM的表现不如使用跨度破坏训练的编码器-解码器,后者训练模型预测一定比例的“被破坏”令牌。尽管这取得了卓越的性能,但它需要高度动态的批处理,并且仍然有许多令牌未被训练,因为 [12 (https://arxiv.org/html/2605.18807#bib.bib5)] 发现仅15%的破坏率就能达到最佳性能。
我们提出一种用于预训练的新颖注意力掩码¹,称之为**双因果块掩码**。我们提出的架构由两个解码器组成。第一个解码器使用标准因果掩码,接收标准输入。第二个解码器对第一个解码器的输出进行交叉注意力,同时也对输入进行交叉注意力,但输入采用我们的方法进行掩码。我们将输入拆分为“块”,每个块内部使用完全自注意力,块之间使用因果交叉注意力。这样,我们从输入中的每个令牌都获得损失信号,并且实现了固定的令牌长度。
[^1]: 代码见 https://github.com/ashlab11/block-based-double-decoder
参见标题†
图 1: 不同规模的解码器、双解码器和编码器-解码器模型的损失与令牌数对比图
## 2 先前工作
原始Transformer架构 [16 (https://arxiv.org/html/2605.18807#bib.bib1)] 由一个双向自注意力编码器和一个用于自回归生成的因果解码器组成。编码器处理完整的输入序列,为每个令牌生成上下文表示。相比之下,解码器应用掩码(因果)自注意力,并额外引入交叉注意力,允许每个解码位置关注所有编码的输入表示。然而,近年来发现,对于生成式建模,仅解码器模型更具可扩展性且性能更优 [11 (https://arxiv.org/html/2605.18807#bib.bib3), 1 (https://arxiv.org/html/2605.18807#bib.bib6)]。与此同时,仅编码器架构 [2 (https://arxiv.org/html/2605.18807#bib.bib7)] 使用令牌掩码训练,在下游分类类任务微调时表现良好。
然而,[12 (https://arxiv.org/html/2605.18807#bib.bib5)] 证明任何分类任务本质上都可以描述为文本到文本的问题,从而使其处于生成模型的能力范围内。他们的发现包括对几种架构和预训练目标的详尽搜索,包括仅解码器(标准LM目标)、仅解码器(PrefixLM目标)以及编码器-解码器(跨度破坏)等。他们发现,在模型所承担的所有下游任务中,包括问答、摘要和翻译,编码器-解码器与跨度破坏的组合在一系列预测任务中表现最佳。在跨度破坏中,连续的令牌跨度被“破坏”或替换为哨兵令牌,每个句子分配一个唯一的ID。目标序列由缺失跨度的拼接组成,模型被训练来自回归地生成这些缺失跨度。通常,只有一小部分令牌(约15%)被破坏,在50%破坏率时会出现性能严重下降。
实验表明,在可比较的训练设置下,这种架构与跨度破坏的组合优于仅解码器语言模型(LM)和前缀变体。此外,在相同模型架构内,跨度破坏作为预训练目标在可比训练设置下也优于标准LM预训练目标。
我们注意到,尽管跨度破坏性能很高,但它作为预训练目标存在几个问题。首先,[12 (https://arxiv.org/html/2605.18807#bib.bib5)] 研究了何种跨度破坏率性能最佳,并确认了 [2 (https://arxiv.org/html/2605.18807#bib.bib7)] 的发现,即15%的跨度破坏达到最优训练效果。然而,这意味着只有这15%的令牌产生损失信号,因此大多数令牌未被训练。这造成了根本性的权衡:一方面通过增加预测难度来增强长距离推理,另一方面相对于总计算成本最大化监督密度。此外,跨度破坏需要动态批处理,因为对于给定的批次,预测的令牌数量在不同训练样本中并不恒定。这需要令牌填充,浪费计算和内存。我们假设另一种预训练目标可以与跨度破坏表现一样好,但不会出现这些低效问题。
此外,先前的工作如 [18 (https://arxiv.org/html/2605.18807#bib.bib8)] 表明,同时捕获局部依赖关系和全局信息的注意力模式可以在降低计算成本的同时实现强性能。尽管我们追求训练时的效率超过 \(O(n)\) 注意力,但我们从这种理念中汲取灵感,采用一种能够同时捕获局部和全局相互依赖性的架构。
## 3 架构对比
### 3.1 基于块的双解码器
<figure>
<img src="fig2.png" alt="Visual explanation of the decoder attention mask for an example sentence.">
<figcaption>图2:对示例句子解码器注意力掩码的可视化解释。分成三部分,我们看到解码器并行看到三个上下文-响应对:
(1): 空上下文,[BOS, A] 响应
(2): [BOS, A] 上下文,[B, C, D] 响应
(3): [BOS, A, B, C, D] 上下文,[E, F] 响应
因此每个令牌在前向传播中恰好出现在损失中一次。</figcaption>
</figure>
在本文中,我们希望创建一种用于下一个令牌预测的架构,它保留编码器-解码器架构的优点——特别是显着的KV缓存减少和在边缘设备上的易用性——同时不牺牲仅解码器Transformer的训练效率。在创建这样的架构时,我们考虑两个标准。首先,**损失信息密度**:打包训练样本中的每个令牌应在每次前向传播中贡献一个损失信号。第二,**序列长度静态性**:打包后的序列长度不应随训练批次变化,从而确保吞吐量可预测且最大化效率。
为了启发我们的架构,我们首先考虑基本的PrefixLM编码器-解码器训练目标,它两个标准都不满足。训练此类模型时,必须在每个批次中随机选择分割点,将 \(N\%\) 放在编码器中,\((100-N)\%\) 放在解码器中。只有解码器部分贡献损失,因此损失信息密度为 \((100-N)\%\),并且令牌计数随批次和 \(N\) 的变化而波动。这导致动态长度批处理(因为 \(N\) 每批次变化)和每令牌损失信息降低,两者结合严重损害训练效率。我们通过提出的架构——基于块的双解码器,解决了这两个问题。
具体来说,基于块的双解码器由两个堆栈组成。第一个称为**上下文解码器**,是一个标准的因果仅解码器Transformer;它接收完整输入并为每个令牌输出因果潜在变量 \(h_t\)。第二个称为**生成解码器**,接收三个输入:(1) 来自上下文解码器的因果潜在变量,(2) 令牌序列本身,(3) **块分区**:严格递增的索引序列 \(0 = b_0 < b_1 < \ldots < b_K = T\),将序列分割成K个连续子序列。对于图2中长度为7的序列,分区 (0, 2, 5, 7) 创建块 [BOS, A]、[B, C, D]、[E, F]。
在生成解码器中,块 \(k\) 中位置 \(t\) 的每个查询关注两个子序列:**块内键**,包含块 \(k\) 中位置 \(\leq t\) 的令牌(即因果自注意力),以及**跨块键**,包含所有块 \(m < k\) 中 \(s\) 位置的上下文解码器潜在变量 \(h_s\)(完全交叉注意力)。图2中的掩码是这两种注意力机制的并集。
这两种操作的组合存在一个微妙之处。将它们作为两个独立的注意力机制并残差相加输出,就像普通编码器-解码器中常见的那样,会引入三个问题。首先,两个操作应用的顺序在架构上变得重要,尽管完全是任意的。其次,某些查询行在交叉注意力掩码下没有键(例如图2中的 [BOS, A]),导致相应的 softmax 未定义。第三,标准注意力的注意力汇聚行为丢失:使用两个独立的 softmax,操作被迫将相等的概率质量分配给两个子序列,即使只有一个对预测下一个令牌重要。这三个问题可以通过仅进行一次注意力操作来解决,该操作根据 (查询, 键) 索引对使用两个不同的键矩阵。然而,据我们所知,目前还没有支持这种双键机制的快速注意力实现,因此我们**分别**计算注意力,并在事后组合它们的对数-求和-指数归一化。这在数学上等同于执行单一的注意力操作,但允许通过 PyTorch 的 FlexAttention 函数快速实现。我们注意到,从计算角度来看,这远不如“理想”方法高效,理想方法将直接利用所选块引入的稀疏性来最小化两次 SDPA 应用中的总乘法次数,从而削减除计算额外 KV 矩阵之外的所有额外计算。我们将这种理想方法的创建留给未来工作。
该架构实现了前面提到的两个标准。无论选择多少个块,每个令牌在前向传播中恰好出现一次在损失中(尽管第一个小块具有空的上下文解码器,因此必须在生成解码器中进行所有推理)。此外,唯一的动态组件是块列表,每批次变化;然而,它只影响生成解码器的注意力掩码,可以通过 PyTorch 的 FlexAttention 以最小延迟处理。并且,上下文解码器的存在保留了普通编码器-解码器架构的优势:在推理时,上下文解码器对提示运行一次,因此不需要 KV 缓存,实现了显著的加速。
本文提到的三种架构——仅解码器、编码器-解码器和双解码器——在训练和推理期间的计算需求差异很大,即使在参数和令牌匹配的情况下也是如此。以下各节将详细描述这些差异。
### 3.2 训练时间对比
尽管我们的主要关注点在于仅解码器、编码器-解码器和双解码器模型在推理延迟上的巨大差异,但三种架构在训练计算量上也存在差异,标准的 \(6NT\) 启发式方法掩盖了这些差异。在附录 A.2 中,我们推导了考虑架构的 FLOP 公式:双解码器因其额外的 KV 投影而增加额外计算,而编码器-解码器通过在跨度破坏下向解码器送入更少的令牌来节省计算。具体来说,对于序列长度为 \(T\)、隐藏维度为 \(d\)、层数为 \(L\) 的仅解码器模型,近似训练 FLOP 计数为
\[ L(72Td^2 + 12T^2d). \]
对于使用跨度破坏训练的编码器-解码器,填充后的编码器输入序列长度为 \(T_{in}\),解码器序列长度为 \(T_{out}\),这是
\[ L((52T_{in} + 28T_{out})d^2 + (4T_{out}^2 + 4T_{in}T_{out} + 8T_{in}^2)d). \]
对于双解码器的高效实现,
\[ L(76Td^2 + 12T^2d). \]
尽管双解码器模型比仅解码器模型需要更多的训练计算,但在高效实现时差异很小:在 \(T = T_{in} = 2048\) 和 \(T_{out} = 256\) 时,双解码器仅比仅解码器多使用 \(2.4\%\) 的 FLOP(而编码器-解码器**少**使用 \(21\%\) 的 FLOP)。将这些公式与我们的经验规模定律结合,可以比较每种架构达到持出测试集上匹配困惑度所需的计算量。
### 3.3 推理时间对比
如第 3.2 节所述,由于每个Transformer块中额外的 KV 矩阵,双解码器在训练计算上往往略**逊于**仅解码器模型。然而,这一劣势被推理期间的众多优势所抵消,我们将在下面详述。所有优势都直接源于训练和推理中上下文与响应的自然分离,并且其中一个优势因经典编码器-解码器模型的双向性质而无法实现。
1. 由于生成相似文章
WAV:面向深度仅解码器Transformer的多分辨率块残差路由
本文提出多分辨率残差路由方法WAV v1,这是块注意力残差机制的扩展,通过引入方向性细节基来增强块表示,从而改进深度仅解码器Transformer的训练效果。
学习跳跃块:自我发现的超度量路由用于硬件加速稀疏注意力
本文介绍了动态超度量注意力(Dynamic Ultrametric Attention),这是一个框架,其中Transformer在训练期间学习每头块稀疏路由拓扑,然后在推理时将这些拓扑卸载到自定义的Triton块稀疏内核上,与密集注意力相比,实现了高达28倍的加速和98.4%的内存减少。
内存高效型循环Transformer:循环语言模型中的计算与内存解耦
提出内存高效型循环Transformer(MELT),这是一种新型循环大语言模型架构,通过跨循环共享单一KV缓存,并结合插值过渡与注意力对齐蒸馏的分块训练方法,实现了推理深度与内存消耗的解耦。
BudgetDraft:面向稀疏KV投机解码的接受感知多视图训练
BudgetDraft提出了一种多视图训练方法,用于投机解码,将稀疏KV起草者与全KV验证者对齐,在中长上下文推理中实现了显著的加速。
ResBM:一种基于Transformer的新型架构,用于低带宽流水线并行训练,实现128倍激活压缩 [R]
ResBM提出了一种基于Transformer的架构,采用残差编码器-解码器瓶颈用于流水线并行训练,在保持收敛的同时实现了128倍激活压缩。该工作通过减少阶段间通信开销,推进了去中心化、互联网级分布式训练的发展。