面向长上下文大语言模型的训练-推理一致性分段执行

arXiv cs.CL 论文

摘要

本文提出了一种面向长上下文大语言模型的训练-推理一致性分段执行框架,旨在解决全上下文训练与受限推理机制之间的不匹配问题,在显著降低内存占用的同时实现了相当的性能。

arXiv:2605.11744v1 发布类型:新论文 摘要:基于 Transformer 的大语言模型在长上下文生成中面临严峻的可扩展性挑战,这主要源于全上下文注意力机制带来的计算和内存成本。在实际的计算和内存约束下,许多注重推理效率的长上下文方法仅在推理阶段采用有限上下文或分段级执行来提升效率,而在训练阶段仍使用全上下文注意力,从而导致训练与推理在执行方式及状态转移语义上出现不匹配。基于这一洞察,我们提出了一种训练-推理一致的分段级生成框架,其中训练和推理遵循相同的分段级前向执行语义。在训练过程中,通过限制梯度传播仅限于从前一相邻分段携带过来的 KV 状态来强制保持与推理的一致性,同时在允许前向传播期间针对特定头访问过去的 KV 状态,但不将这些状态纳入梯度传播。在长上下文基准测试中,我们的方法实现了与全上下文注意力相当的性能,在与强大的推理高效基线方法的对比中取得了具有竞争力的延迟-内存权衡,并在极长上下文长度下显著提升了可扩展性(例如,在 128K 上下文长度下,相比使用 FlashAttention 的全上下文注意力,峰值预填充内存降低了约 6 倍)。
查看原文 导出为 Word 导出为 PDF
查看缓存全文

缓存时间: 2026/05/13 06:17

# 长上下文大语言模型的训练-推理一致分段执行
来源: https://arxiv.org/html/2605.11744
###### 摘要

基于 Transformer 的大语言模型在长上下文生成中面临严重的可扩展性挑战,这源于全上下文注意力机制带来的计算和内存开销。在实际的计算和内存约束下,许多推理高效的长上下文方法仅在推理阶段采用有界上下文或分段级执行来提高效率,而在训练阶段仍继续使用全上下文注意力,导致训练与推理的执行方式及状态转换语义出现不匹配。基于这一洞察,我们提出了一种训练-推理一致的分段生成框架,其中训练和推理遵循相同的分段前向执行语义。在训练期间,通过将梯度传播限制在仅来自紧邻前一段的 KV 状态来强制实现与推理的一致性,同时允许注意力头在前向传播中以特定方式访问过去的 KV 状态,但不将这些状态纳入梯度传播。在多个长上下文基准测试中,我们的方法实现了与全上下文注意力相当的性能,同时在延迟与内存的权衡方面表现出竞争力,且在极长上下文长度下显著提高了可扩展性(例如,与使用 FlashAttention 的全上下文注意力相比,在 128K 上下文下的峰值预填充内存降低了约 6 倍)。

长上下文语言模型,高效注意力,训练-推理对齐,分段执行

## 1 引言

长上下文建模对于大语言模型(LLM)而言日益重要(Achiam et al., 2023; Team et al., 2024; Anthropic, 2024),支撑着文档理解、持续对话和复杂推理等实际应用(Bai et al., 2024)。然而,全上下文自注意力机制的二次计算复杂度从根本上限制了 Transformer 模型在长上下文场景下的可扩展性(Vaswani et al., 2017; Keles et al., 2023)。因此,长上下文推理通常采用受限的执行机制,如有限上下文或块状注意力,以降低计算成本(Xiao et al., 2024; Liu et al., 2025)。最近的工作通过保持注意力语义不变的执行级优化,显著提高了长上下文推理的效率,在不改变模型输出的情况下减少了内存消耗和实际推理成本(Dao, 2024; Agrawal et al., 2023)。然而,随着上下文长度的不断增加,仅依靠这种保持语义的执行级优化所带来的资源节省往往不足以应对实际需求(图 1),因此通常采用更为严格的执行策略,如窗口注意力或稀疏注意力机制。

> **图 1:** 长上下文预填充期间的峰值 GPU 内存消耗。

尽管这些方法在推理时非常有效,但大多数现有方法仅在推理阶段实施这些受限执行机制,而在训练阶段仍依赖全上下文注意力。这导致训练和推理在执行方式以及跨段状态演化方面存在不匹配。结果,模型可能依赖于训练时可用但在受限推理机制下不可用的信息,从而损害长上下文设置下的稳定性和泛化能力。

为了解决这一限制,我们提出了一种训练-推理一致的分段生成框架,将分段执行视为一种共享的建模假设,而非单纯的推理时优化。我们将序列划分为段,并仅携带固定大小的 KV 尾部作为**唯一**的可微分跨段接口状态,该状态在训练和推理中被**原样**使用。训练通过截断后端传播(TBPTT)将跨段信用分配限制在最近 $K$ 个状态转换上;在这种严格受限的递归下,TBPTT 计算推理一致目标的精确梯度,防止依赖推理时不可用的信息。为了访问超出携带 KV 范围之外的证据,模型以前向-only(无梯度)的方式额外消费检索到的 KV 前缀,这部分不参与状态递归。在架构上,我们通过头部和层级稀疏的长距离头部来实现这一设计,而大多数头部支持局部、携带状态的 computation,从而确保训练和推理之间的执行语义严格对齐。

我们的贡献如下:

- 我们提出了一种用于长上下文建模的分段级建模框架,通过设计强制实现训练-推理一致性,将跨段信息流分解为局部连续性通道和一个提供长距离条件化的独立前向-only 机制。
- 我们证明,通过将跨段学习限制在可控的接口状态上,可以在不引入持久记忆变量的情况下理论上保证训练-推理对齐。在这种受限表述下,截断后端传播计算的是推理一致目标的精确梯度,而非近似值。
- 我们在多个长上下文基准测试和上下文长度上对提出的训练-推理一致框架进行了实证验证,展示了在受限执行下的强劲性能,消融实验表明 $K=1$ 的 TBPTT 是充分且最优的,并显著提高了可扩展性(例如,与全注意力相比,128K 上下文预填充内存降低约 6 倍)。

> **图 2:** 训练-推理一致的分段执行。序列逐段处理,具有两个跨段输入:*携带的 KV 尾部* $C_{i-1}$(唯一*可微分*的跨段传播状态)和可选的从*仅历史* KV 池中读取的检索前缀 $R_{i-1}$。在训练期间,深度为 $K$ 的 TBPTT 截断沿状态链的信用分配(红色叉号),因此梯度最多通过 $C$ 流动 $K$ 个段转换(蓝色大括号),而检索路径和早期历史是前向-only(无梯度)。

## 2 相关工作

**执行级优化。** 最近的工作通过执行级和系统级优化,提高了基于 Transformer 模型中基于精确注意力的长上下文推理效率,而未修改模型参数、训练目标或注意力语义。代表性例子包括 FlashAttention 和 FlashAttention-2 等内核级优化(Dao et al., 2022; Dao, 2024),vLLM 和 SARATHI 等用于高效 KV 缓存管理和调度的运行时系统(Kwon et al., 2023; Agrawal et al., 2023),以及探索异构内存卸载以进行大规模推理的系统级方法,例如 FlexGen(Sheng et al., 2023)。虽然这些方法减少了常数因子并提高了中等上下文长度下的吞吐量,但它们保留了全上下文注意力语义,因此并未解决在需要受限执行的更长上下文长度下遇到的计算和内存挑战。

**推理时的受限执行。** 另一类工作仅在推理时限制注意力连接性或保留状态,而保持训练时的执行不变,从而实现长上下文推理。基于流式的方法,如 StreamingLLM(Xiao et al., 2024)和 LM-Infinite(Han et al., 2024),通过调整在短上下文中训练的模型的推理时注意力模式来实现零样本长度泛化,同时保持训练时执行不变。在另一个方向上,MInference(Jiang et al., 2024)通过推理时的稀疏注意力加速长上下文预填充,同时在训练期间保留密集注意力。受限执行也可能源于状态级约束而非注意力稀疏化,例如 ChunkKV(Liu et al., 2025),它在推理期间选择性地压缩和保留 KV 状态而无需重新训练。尽管这些方法以不同方式施加限制,涵盖注意力范围、稀疏性和状态保留,但它们依赖于与推理时行为不同的训练时执行假设,导致推理时注意力连接性或保留状态使用的不匹配。

**训练-推理对齐。** 训练-推理对齐已在推理采用受限执行的场景中得到探索,通过在训练中明确纳入可比约束来实现。Longformer(Beltagy et al., 2020)用固定的稀疏模式替换全自注意力,使得训练和推理在相同的注意力连接性下运行。核心上下文感知(CCA)(Chen et al., 2025)通过在适应过程中一致地应用机制,在减少上下文计算图上强制对齐。对齐也在流式或分段执行中得到了研究:Shiftable Context(Raffel et al., 2023)强制训练和推理之间保持一致的分段结构,而滑动窗口注意力训练(Fu et al., 2025)直接使用窗口注意力训练模型,以避免仅在推理时引入此类约束时出现的性能退化。综上所述,这些方法通过在训练和推理阶段强制保持一致的注意力连接性或上下文结构来缓解训练-推理不匹配。

**基于记忆和递归的模型。** 一种替代的建模范式通过引入持久记忆机制或在段之间传播递归状态来扩展有效上下文长度。Transformer-XL(Dai et al., 2019)是最早在 Transformer 中引入带有梯度截断的分段递归的工作,而我们的方法正式化了训练-推理一致的执行语义,并将短期可微分状态传播与前向-only 的长距离检索分开。压缩 Transformer(Rae et al., 2020)建立在 Transformer-XL 的基础上,通过学习的压缩保留旧状态。递归记忆 Transformer(Bulatov et al., 2022)进一步引入了显式记忆 token,这些 token 被递归更新并训练以存储跨段的全球信息。记忆 Transformer(Wu et al., 2022)则依赖于通过检索访问的外部键值记忆来支持长距离回忆。这些方法依赖于显式的、持久的跨段状态来携带长距离信息,其在训练期间的更新动态不一定与推理时遇到的执行语义对齐。相比之下,我们的方法避免了持久记忆状态,并在分段级执行下强制实现训练-推理一致性。

> **图 3:** 头部和层级稀疏的长距离检索。(a)在非长距离层 $\ell \notin \mathcal{L}_{\text{long}}$ 中,局部头部关注段内 token 和来自上一段的携带 KV 状态(绿色),而长距离头部仅使用段内因果注意力(橙色)。(b)在启用长距离的层 $\ell \in \mathcal{L}_{\text{long}}$ 中,局部头部保持不变,而长距离头部额外关注从仅历史 KV 池中检索的前缀(蓝色)。在所有情况下,当前段内的注意力保持因果性。

## 3 方法

在长上下文语言建模中,全上下文注意力的计算和内存成本使得大规模下的无限制执行不切实际,导致使用分段或其他受限的注意力机制。许多现有方法仅在推理时施加此类约束,而在训练期间保留全上下文注意力,导致训练和推理执行语义之间的不匹配。在本节中,我们描述了一种训练-推理一致的分段执行框架,其中训练和推理遵循相同的分段级执行方案,在同一执行语义下显式约束跨段状态传播和长距离访问。

### 3.1 训练-推理一致的分段执行

#### 设置

如图 2 所示,我们将 token 序列划分为 $N$ 个非重叠段 $\{x^{(i)}\}_{i=1}^N$,其中 $m_i = \|x^{(i)}\|$ 表示段 $i$ 的长度。为了在有限注意力下实现分段推理,我们暴露了一个受限的跨段接口状态 $C_i \in \mathcal{C}$(在我们的实现中是固定大小的携带 KV 接口),以及由长距离模块提供的前向-only 检索前缀 $R_i \in \mathcal{R}$。

###### 定义 3.1(分段级执行语义)。

对于每个段 $i$,模型在训练和推理时运行相同的前向算子:

$$ (C_i, o^{(i)}) = F_\theta(x^{(i)}, C_{i-1}, R_{i-1}) \tag{1} $$

其中 $o^{(i)}$ 表示用于计算语言模型损失的输出,$C_{i-1}$ 是模型可用的**唯一**可微分跨段接口状态,$R_{i-1}$ 是前向-only 检索前缀。段损失为

$$ \ell_i(\theta; C_{i-1}, R_{i-1}) = -\sum_{t=1}^{m_i} \log p_\theta(x_t^{(i)} \mid x_{<t}^{(i)}, C_{i-1}, R_{i-1}). \tag{2} $$

#### 操作解释。

对于分为三个段 $x^{(1)}, x^{(2)}, x^{(3)}$ 的序列,且 $C_0 = R_0 = \emptyset$,公式 (1) 中的分段级递归展开为

$$
\begin{aligned}
(C_1, o^{(1)}) &= F_\theta(x^{(1)}, \emptyset, \emptyset), \\
(C_2, o^{(2)}) &= F_\theta(x^{(2)}, C_1, R_1), \\
(C_3, o^{(3)}) &= F_\theta(x^{(3)}, C_2, R_2).
\end{aligned}
$$

这里 $o^{(i)}$ 用于预测当前段 $x^{(i)}$ 中的 token,而 $C_i$ 是为下一段产生的携带状态。因此,$C_{i-1}$ 和 $R_{i-1}$ 是条件...

相似文章

学习,快与慢:走向持续适应的LLMs

Hugging Face Daily Papers

一种针对LLMs的快慢学习框架,将固定的慢权重与优化的快上下文权重相结合,在持续学习场景中实现了高达3倍的样本效率提升,并减少了灾难性遗忘。

跨异构任务的自演化LLM记忆抽取

Hugging Face Daily Papers

研究者推出BEHEMOTH基准与CluE聚类提示优化,使LLM能从多样化任务中抽取并保留异构记忆,相比既往自演化框架提升9%。

LongAct:利用内在激活模式进行长上下文强化学习

Hugging Face Daily Papers

LongAct 提出了一种显著性引导的稀疏更新策略,通过选择性更新与查询和键向量中高幅值激活相关的权重来改进 LLMs 的长上下文推理能力,在 LongBench v2 上实现了约 8% 的提升。