ART:高效大语言模型解码中的注意力运行时终止

arXiv cs.CL 论文

摘要

本文提出ART,一种轻量级的运行时机制,它在LLM解码过程中追踪累积的注意力输出,并在进一步贡献变得微不足道时终止不必要的KV块访问,从而在保持相当精度的同时实现20%更高的生成吞吐量。

arXiv:2606.00024v1 公告类型:新 摘要:在大语言模型(LLM)中,长上下文解码受到内存带宽的严重制约,因为需要获取庞大的键值(KV)缓存。现有的多数KV管理方法依赖于解码前的仅键剪枝,尽管有证据表明注意力输出依赖于键和值的联合,但将值纳入其方法会带来过高的额外开销。在本文中,我们提出注意力运行时终止(ART),一种轻量级的运行时机制,在内核执行期间跟踪累积的注意力输出,并在进一步贡献变得微不足道时终止后续KV块的访问。这种设计使ART与现有的基于键的KV缓存管理方法正交,从而能够与它们无缝集成。在LongBench基准上的实验表明,与最先进的基线相比,ART在大批量下实现了20%更高的生成吞吐量,同时保持了相当的精度。
查看原文
查看缓存全文

缓存时间: 2026/06/02 15:35

# 高效大语言模型解码中的注意力运行时终止
来源:https://arxiv.org/html/2606.00024
###### 摘要

长上下文解码在大语言模型(LLMs)中受到严重制约,原因是获取大量键值(KV)缓存所需的内存带宽。大多数现有的 KV 管理方法在解码前依赖于仅基于键的剪枝,尽管有证据表明注意力输出共同依赖于键和值,因为在其方法中纳入值会带来难以承受的额外开销。在本文中,我们提出注意力运行时终止(ART),一种轻量级的运行时机制,它在内核执行期间跟踪累积的注意力输出,并在后续贡献变得可忽略时终止后续的 KV 块访问。这种设计使得 ART 与现有的基于键的 KV 缓存管理方法正交,从而能够无缝集成。在 LongBench 基准上的实验表明,ART 在大批量处理时比最先进的基线实现了 20% 更高的生成吞吐量,同时保持了相当的准确率。

## 1 引言

大语言模型(LLMs)[1, 2] 在自回归解码过程中依赖不断增长的键值(KV)缓存来存储先前生成 token 的中间表示。随着序列长度的增加,每个新的查询 Q 需要关注不断扩展的缓存键 K 和值 V,导致延迟和内存使用量线性增长。因此,有效管理 KV 缓存(通常称为*KV 剪枝*或*KV 驱逐*)已成为加速长上下文推理的核心挑战。

大多数现有方法采用基于键中心的 KV 缓存剪枝,仅保留由启发式或学习的重要性估计器选择的 token。代表性方法 [3, 4, 5, 6] 通常基于查询-键相似度或注意力分数来构建这些估计器。然而,这些代理并不总是反映 token 对模型输出的真实影响。如图 1 所示,这种失准揭示了基于注意力的剪枝的一个基本局限性:具有低注意力权重的 token 仍可能通过其值表示产生实质性影响。

参见图 1 的说明:**注意力分数与输出贡献的对比**。我们比较了基于标准注意力分数(\(\max_h \alpha_{h,j}\))的 token 排名与其通过加权值向量的 L2 范数测量的实际输出贡献。虽然存在一般相关性,但橙色点突出显示了一个关键的 token 子集:那些具有高功能影响(前 20% 贡献)却被分配低注意力分数的 token,这揭示了基于注意力的剪枝的一个关键局限性。

最近的一些研究 [7, 8, 9] 试图将值信息纳入 KV 缓存管理,展示了值感知建模的潜在优势。然而,这些方法通常依赖于额外的预测器、预计算或离线分析来估计值贡献,在推理过程中引入了不可忽视的开销。这引发了一个自然的问题:

**我们能否在运行时以可忽略的额外成本考虑键和值的联合效应?**

参见图 2 的说明:**我们的注意力运行时终止(ART)解码过程概览**。随着 KV 块被顺序处理,ART 监控中间注意力输出,并在输出稳定后触发提前终止,在解码过程中跳过剩余的 KV 块。

现代 FlashAttention 风格的内核 [10, 11] 在逐块执行过程中自然暴露了一系列中间注意力输出。随着 KV 块被逐步加载和处理(图 2),每一步都会增量更新注意力输出,提供新处理块贡献了多少额外信息的直接信号。当这些增量更新变得足够小时,即使尚未遍历所有 KV 块,进一步的计算和内存访问可能已无必要。这一洞察使得在运行时无需依赖昂贵的预解码估计即可渐进评估键和值的联合影响。

基于这一洞察,我们提出**注意力运行时终止(ART)**,一种用于提前终止注意力计算的轻量级输出感知机制。ART 的概览如图 2 所示。ART 直接在逐块注意力内核内运行,在 KV 块处理过程中逐步监控部分注意力输出的演变。与事先定义的静态剪枝规则不同,ART 是输出感知的,动态监控部分注意力输出在幅度和方向上的演变,以确定是否需要后续的 KV 块。因此,ART 与现有的 KV 缓存管理方法正交,可以无缝集成以在解码过程中优化缓存利用。

我们在长上下文基准 LongBench [12] 上评估了 ART,结果表明 ART 持续提高注意力效率而不牺牲生成质量。特别地,在大批量处理时,ART 比最先进的基线实现了 20% 更高的生成吞吐量。

我们的贡献总结如下:

- • 我们提出了 ART,一种输出感知的运行时机制,动态确定在注意力执行过程中是否需要处理更多的 KV 缓存块。
- • 我们设计了一种基于稳定性的终止标准,通过联合监控中间注意力输出的尺度和方向变化来量化输出空间中的收敛性,从而捕捉键和值的组合影响。
- • 我们证明 ART 是一个轻量级且可组合的模块,能够与现有的 KV 缓存管理方法无缝集成,并在长上下文基准上实现超过 20% 的生成吞吐量加速,且准确率损失可忽略。

## 2 相关工作

在本节中,我们回顾先前关于高效长上下文推理的工作。我们首先总结 KV 缓存管理方法,然后讨论高效注意力内核和推理系统。

### 2.1 KV 缓存管理

早期的高效长上下文推理方法采用基于模式的剪枝策略,例如 LM-Infinite [13] 和 Attention Sink [3],它们通过丢弃较旧的 token 或将注意力集中在一小组锚点位置来维护固定大小的活动上下文。为了超越静态启发式,后续方法使用语义、层次或动态标准来估计 token 重要性。ChunkKV [6]、SnapKV [14]、Quest [5] 和 PyramidKV [15] 通过结构化分组保留粗粒度的全局信息,同时过滤冗余 token。H2O [4]、TOVA [16] 和 Scissorhands [17] 进一步利用重要性分数的时间持久性,在解码步骤中自适应管理 KV 缓存。

尽管这些方法有效,它们本质上是**基于键中心**的,隐含假设注意力矩阵单独决定输出,忽视了值(V)幅度的作用。最近的工作 [8] 挑战了这一假设,表明值(V)也编码了关键的语义信号,并显著影响注意力结果。然而,当前的值感知方法 [7, 9] 通常依赖辅助预测器或离线分析,带来大量开销,限制了它们在在线服务中的实用性。这留下了一个空白:如何在**运行时**以无额外延迟的方式捕捉键和值的联合影响。

参见图 3 的说明:**ART 集成到 FlashAttention 执行管线中**。FlashAttention 将基于 DMA 的 KV 块预取从 HBM 与 Tensor Core 计算以逐块方式重叠。ART 在计算期间对演变的注意力输出执行轻量级运行时检查。一旦检测到收敛(例如,在第 \(i\) 步),ART 通过阻止进一步的 KV 块预取和计算来提前终止管线,减少不必要的内存流量和计算,而不影响最终输出。

### 2.2 高效注意力内核与推理系统

在内核层面,FlashAttention [10] 和 FlashAttention-2 [11] 通过 IO 感知的平铺和在线 softmax 显著提升了注意力效率。通过逐块计算注意力,这些内核减少了 HBM 与 SRAM 之间的内存流量,同时增量累积注意力输出。这种逐块执行模型构成了我们方法的关键基础,因为它暴露了中间累积状态,可以在运行时进行监控而无需修改注意力公式。

基于这些优化的内核,服务框架如 vLLM [18] 和 SGLang [19] 通过 PagedAttention 和高级调度进一步提升了端到端效率,解决了内存碎片化问题并提高了服务吞吐量。然而,这些系统将注意力内核视为原子操作符。我们的工作通过打开这个原子操作符,直接在内核执行中引入轻量级运行时提前终止机制,从源头减少不必要的 KV 访问,从而补充了现有框架。

## 3 方法论:ART

在本节中,我们介绍**注意力运行时终止(ART)**。我们首先检查现代注意力内核的执行特性,然后介绍基于稳定性的运行时终止机制,最后讨论 ART 的集成和正确性。

### 3.1 FlashAttention 的执行特性

尽管注意力输出由所有键上的全局 softmax 定义,现代 FlashAttention 风格的内核提供了一个关键的执行特性:注意力以流式、分块方式计算,其中输出在计算和内存传输重叠时增量累积(图 3)。这种执行模型以可忽略的成本暴露了中间注意力状态,使得能够推理输出收敛并提前终止注意力执行。

FlashAttention 的流式执行自然产生一系列中间注意力输出,随着 KV 块的处理而逐步更新。这种执行模式的一个重要含义是,注意力输出可能在遍历所有 KV 块之前就已足够稳定,这为提前终止提供了机会,而不会实质性地影响最终结果。

为了形式化这种行为,我们考虑一个查询块 \(Q \in \mathbb{R}^{M \times d}\) 关注键 \(K \in \mathbb{R}^{N \times d}\) 和值 \(V \in \mathbb{R}^{N \times d}\) 的缩放点积注意力:

\[
O = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V. \tag{1}
\]

参见图 4 的说明:**最近 KV 保留的注意力输出收敛性**。该图展示了全 KV 缓存与截断的最近 KV 窗口的注意力输出之间的相对 L2 误差,作为加载比例的函数。

FlashAttention 通过将键和值划分为瓦片/块 \(\{ (K_t, V_t) \}_{t=1}^T\) 并顺序处理它们来计算公式 (1)。使用数值稳定的流式 softmax 公式,内核维护一个内部累加器,在每个块后更新。我们记 \(O^{(t)}\) 为处理前 \(t\) 个块后的注意力输出。随着 \(t\) 增加,\(O^{(t)}\) 收敛到最终输出 \(O^{(T)}\)。如图 4 所示,注意力输出在处理 KV 块时迅速收敛。只有部分块贡献了显著的误差减少,而后续的块仅引起边际变化。这种行为表明,注意力输出可能在遍历所有 KV 块之前就已稳定。

### 3.2 基于稳定性的运行时终止

鉴于早期稳定行为,一个自然的问题是如何在运行时确定剩余的 KV 块是否会显著影响最终的注意力输出。一种基于注意力分数的直接方法是不够的,因为高注意力权重可能被分配给信息量较少甚至为零的值向量(图 1)。这种局限性表明,仅基于键的信号不能可靠地指示何时进一步的注意力计算变得冗余。

受此洞察的启发,我们的想法是跟踪中间注意力输出的演变。具体来说,不是通过注意力分数估计重要性,而是监控部分注意力输出 \(O^{(t)}\) 在加入更多 KV 块时如何变化。如果输出变得足够稳定,进一步遍历 KV 块不太可能改变最终结果,因此可以安全地终止。通过直接在输出空间测量稳定性,ART 自然捕捉了键和值的联合贡献。

一个剩余的挑战是如何在内核执行期间高效评估这种稳定性。直接计算全向量范数或在全局累加器上执行跨线程归约会引入不可忽视的开销,并削弱提前终止的优势。这就要求一种轻量级机制,可以在无需全向量计算的情况下近似输出收敛性。

在利用 NVIDIA Tensor Cores 的优化注意力内核中(例如 FlashAttention-2 [11]),累加器 \(O^{(t)}\) 通过交织的矩阵乘法累加(MMA)布局分布在线程之间。我们通过利用线程块中领头线程持有的寄存器片段构建一个探针向量 \(x^{(t)} \in \mathbb{R}^m\)。关键的是,由于 MMA 布局的交织性质,这个片段构成了头部维度的确定性、分散子样本,而不是连续切片,提供了分散的覆盖。

相似文章

SparDA:用于高效长上下文 LLM 推理的稀疏解耦注意力

arXiv cs.CL

SparDA 提出了一种解耦稀疏注意力架构,通过添加轻量级"Forecast"投影来预测未来的 KV 缓存需求,从而实现从 CPU 到 GPU 的预取(lookahead prefetching),并降低选择开销。在基于稀疏预训练的 8B 模型上,其 prefill 速度最高可提升 1.25×,decode 速度最高可提升 1.7×,相比非 offload 基线,decode 吞吐量最高可提升 5.3×。

GQLA: 面向硬件自适应大语言模型解码的分组查询潜在注意力

arXiv cs.LG

GQLA 提出了对多头潜在注意力(MLA)的极小修改,在相同训练权重上同时暴露 MQA 吸收路径和 GQA 路径,从而无需重新训练即可实现硬件自适应解码。该方法压缩 KV 缓存并支持张量并行性,通过将 LLaMA-3-8B 从 GQA 转换为 GQLA 得到验证。