自剪枝键值注意力:通过预测未来效用决定何时写入

arXiv cs.LG 论文

摘要

提出了自剪枝键值注意力(SP-KV),一种通过学习预测键值对未来效用的机制,动态剪枝KV缓存,将内存使用和解码速度提升3-10倍,且性能下降极小。模型和效用预测器通过下一词元预测进行端到端联合训练。

arXiv:2605.14037v1 公告类型:新 摘要:在现代测试时计算和代理范式下,语言模型处理越来越长的序列。使用Transformer架构的高效文本生成日益受到键值缓存内存占用和带宽的限制。为解决这一限制,我们提出了自剪枝键值注意力(SP-KV),一种旨在预测未来KV效用以减少长期KV缓存大小的机制。该策略以细粒度运行:一个轻量级的效用预测器为每个键值对打分,同时最近的KV通过局部窗口始终可用,而较旧的键值对仅在其预测效用超过给定阈值时才被写入缓存并用于全局注意力。LLM和效用预测器完全通过下一词元预测损失进行端到端联合训练,并从预训练的LLM检查点进行适配。 SP-KV并非强制固定压缩比,而是执行动态稀疏化:该机制自适应输入,通常将KV缓存大小减少至原来的$3$到$10\times$,较长的序列通常更易压缩。这带来了内存使用和解码速度的大幅提升,而验证损失和广泛下游任务性能几乎没有下降。除了作为有效的KV缓存缩减机制外,我们的方法还揭示了结构化的层和头特定稀疏模式,可用于指导混合局部-全局注意力架构的设计。
查看原文
查看缓存全文

缓存时间: 2026/05/15 06:26

# 学习何时写入:通过预测未来效用进行决策 来源:https://arxiv.org/html/2605.14037  
[1] Meta FAIR [2] MICS, CentraleSupélec  
\*共同贡献  
通信作者  

## 自剪枝键值注意力:通过预测未来效用来学习何时写入  

Manuel Faysse, Maria Lomeli, Matthijs Douze, Pierre-Emmanuel Mazaré, Loïc Cabannes, Wen-tau Yih, Hervé Jégou  
[[[email protected]](mailto:[email protected])(2026年5月13日)  

###### 摘要  
在现代测试时计算和智能体范式中,语言模型处理越来越长的序列。使用Transformer架构进行高效文本生成时,键值缓存的内存占用与带宽正日益成为关键瓶颈。为解决这一问题,我们提出**自剪枝键值注意力**(Self-Pruned Key-Value Attention,SP-KV),这是一种通过预测未来键值对效用来缩小长期KV缓存规模的机制。该策略以细粒度方式运作:一个轻量级的效用预测器为每个键值对打分;近期KVs始终通过局部窗口可用,而较旧的键值对只有在预测效用超过给定阈值时才会被写入缓存并用于全局注意力。语言模型和效用预测器通过端到端联合训练,仅使用下一令牌预测损失,并从预训练的语言模型检查点适配而来。与强制固定压缩比不同,SP-KV执行**动态**稀疏化:该机制会根据输入自适应调整,通常能将KV缓存大小缩减3到10倍,序列越长压缩率通常越高。这带来了内存使用和解码速度的巨大提升,同时对验证损失或广泛下游任务集上的性能几乎没有影响甚至完全没有影响。除了作为一种有效的KV缓存缩减机制,我们的方法还揭示了结构化的逐层和逐头稀疏模式,这些模式可用于指导混合局部-全局注意力架构的设计。  

见图注  
**图1:自剪枝键值注意力概览**:学习到的KV效用预测器调节注意力操作中键值对的使用。推理时,只有效用超过给定阈值τ的KV对被保留在持久化KV缓存中,从而实现内存节省和解码加速。我们始终保留近期(128个令牌)以保持局部交互。训练时,令牌选择被可微分门控替代以保持梯度流动。经过全注意力预训练的模型在继续预训练过程中逐渐稀疏化,无需特定的损失函数。  

## 1 引言  
在推理阶段,Transformer语言模型(Vaswani 2017 Attention Is All You Need; Brown 2020 Language Models are Few-Shot Learners)的大小和速度越来越受内存而非计算量的限制。在自回归生成过程中,键值缓存随序列长度线性增长,并且每个新生成的令牌都需要读取整个缓存。随着部署场景转向长上下文、检索增强和智能体测试时流水线,这种不断膨胀的缓存使得GPU内存流量成为主要的性能瓶颈。这种压力也延伸到后训练阶段,因为长上下文强化学习和集成工具智能体的训练依赖于扩展的解码展开(Zhu 2025 Scaling TTC Agents; Wang 2025 LoongRL)。  

为了缓解这个问题,GQA(Ainslie 2023 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints)和MLA(Liu 2024 DeepSeek-V2)等架构方法通过跨查询头组共享键值来减小KV缓存。其他方法利用了大多数查询-键交互集中在短局部窗口内,而长距离交互相对稀疏的事实(Zhang 2023 H2O)。通常,混合Transformer通过将全局注意力与局部滑动窗口注意力交替使用(Beltagy 2020 Longformer; Rivière 2024 Gemma 2),或者用固定记忆序列机制(如Gated DeltaNet(Yang 2025 Gated DeltaNet))替换某些注意力层,来减少对全局注意力的依赖。  

另一条工作路线利用**读取时稀疏性**:诸如QUEST和DeepSeek Sparse Attention等查询感知方法在解码时只检索部分历史键。虽然解码速度得到提升,但完整缓存仍保留在内存中(Tang 2024 QUEST: Query-Aware Sparsity for Efficient Long-Context LLM Inference; DeepSeek-AI 2025 DeepSeek-V3 Technical Report)。  

在本文中,我们基于这样一个观察:对于给定的注意力头,查询和键大多区分为短期和长期交互,这表明只有一部分过去的键值对在未来的解码中始终有用。这引发了一个问题:是否有必要将每个令牌不加区分地写入长期/持久化内存?如果答案是否定的,那么意味着KV缓存不仅可以被**稀疏读取**,更重要的是还可以**选择性写入**,从而在节省FLOPs的同时节省内存。  

此前,诸如H2O和KVZap等驱逐方法尝试在预填充后利用过去令牌统计信息或针对冻结模型学习的策略来修剪缓存(Zhang 2023 H2O; Jégou 2026 KVZap)。虽然这些技术能大幅减少内存使用,但常常以牺牲模型质量为代价。事后方法的中心局限在于,模型的内部令牌表示没有针对修剪策略进行适配,而修剪机制本身通常是在小型辅助数据集上校准的。这导致训练-测试不匹配,随着压缩力度增大和输入分布发生偏移,该不匹配会更加严重(见第5节)。  

我们做出以下贡献:  
**贡献1.** 我们提出自剪枝KV注意力(SP-KV),一种学习到的**稀疏写入**机制,该机制仅将最有用的键值对选择性地写入持久化KV缓存。一个轻量级效用预测器为每个KV对分配一个效用分数;近期令牌通过局部因果窗口保持可访问,而较早的KV对仅当预测效用超过阈值时才被写入全局缓存并参与全局注意力。语言模型和效用预测器在大规模数据上仅使用下一令牌预测进行联合训练,通常从预训练的全注意力检查点继续预训练。在不同的模型规模和序列分布下,SP-KV实现了高KV缓存缩减,带来内存和解码速度的提升,而对验证损失或下游评估几乎没有影响。我们提供了大量消融实验,并展示了该机制可以迁移到密集注意力之外的混合局部-全局设置。  

**贡献2.** 我们进一步展示了SP-KV可以作为架构探针,用于设计更强的全局/局部注意力Transformer混合体。具体来说,在参考模型上,仅保留学习到的平均SP-KV效用最高的那些注意力头作为全局头,而将剩余注意力头设为局部头,我们得到的混合体在相同KV缓存预算下性能优于标准的交错布局。  

## 2 方法概览  

### 2.1 自剪枝KV机制  
图1展示了在第l层、单个注意力头的**自剪枝KV**机制。设T为序列长度,H^l = [h_0^l, ..., h_{T-1}^l]^⊤ ∈ R^{T × d_model}为第l层的隐藏状态。  

#### 逐键效用预测。  
对于每个键头k,我们预测每个令牌位置s的效用值:  
u_s^{l,k} = σ(f_θ^{l,k}(h_s^l)) ∈ (0,1)  (1)  
其中σ(·)表示逻辑斯蒂sigmoid函数,确保效用值在(0,1)范围内,f_θ^{l,k}(·)是由θ参数化的轻量级效用预测器,为一个2层感知器(MLP)。为简化后续符号,我们省略层和键头上标,将u_s^{l,k}简记为u_s。推理时,我们将效用门控值与阈值τ进行比较以获得二值化值;z_s=1表示位置s的KV对符合**长距离**(全局)注意力的条件;z_s=0则不符合。  
z_s = 1[u_s ≥ τ], z_s ∈ {0,1}.  (2)  

#### 滑动窗口与门控全局注意力。  
为保留局部时间特征,我们**始终**允许在大小为w的因果局部滑动窗口内进行注意力(默认w=128)。对于查询位置t和键位置s,我们定义窗口指示函数1_win(t,s) = 1[0 ≤ t-s < w]。最终输出注意力权重由下式给出:  
a_{t,s} = softmax( (q_t k_s^⊤) / √d_k + M_{t,s})  (3)  
其中掩码M_{t,s}定义为:  
M_{t,s} = { 0, 如果1_win(t,s) = 1 或 (z_s = 1 且 s ≤ t); 否则 -∞ }  (4)  

#### 训练时的软化松弛。  
在训练时,离散门控z_s会阻碍梯度流动。因此我们用一个可微分的松弛替代,使用sigmoid函数对门控进行软化:  
α_s = σ( (u_s - τ) / β )  (5)  
其中β是温度参数(默认为0.01)。然后我们将掩码M_{t,s}替换为基于α_s的加性分数:  
M_{t,s}^{soft} = { 0, 如果1_win(t,s) = 1; 否则 c(α_s - 1) }  (6)  
其中c ≥ 0是一个大的常数(实践中c = 1e6)。这相当于在被丢弃的键值对上施加一个大的负偏置,而非完全屏蔽。该机制使得稀疏选择性写入能够与标准下一个令牌预测损失下的反向传播兼容。  

#### 部署时的高效实现。  
推理时,我们存储一个仅包含满足z_s=1的KV对的稀疏持久化KV缓存。对于每个新生成的查询,计算注意力时,我们使用该稀疏缓存和局部窗口中的键。由于持久化缓存大小是动态的且取决于效用,它可以是部分预填充、递增构建或定期重建的。推理时门控是确定性的(公式2)。  

### 2.2 训练协议  

#### 继续预训练与端到端联合训练。  
我们从预训练的全注意力LM检查点开始,将效用预测器MLP(每层每个注意力头一个)插入架构,并继续在大规模语料库上训练模型。整个模型——LM主干和效用预测器——通过标准的自回归下一个令牌损失进行端到端联合优化。关键的是,我们**不**使用任何辅助稀疏性损失;模型在梯度驱动下自发地学习将低效用分配给那些对未来预测贡献可忽略的键值。这使得门控机制能够直接在该模型的主要优化目标下学习何时写入。  

#### 渐进式稀疏化。  
尽管公式6中的软化松弛使得梯度流经门控机制成为可能,但我们发现,在训练过程中逐渐增加稀疏度有助于收敛。我们从较低的效用阈值开始(近似模拟全注意力),并在固定的步数内逐渐将其增加到目标值τ_target。温度参数β也按照相似的计划进行调整。这种计划防止了过早锁定稀疏模式,并使效用预测器能够逐渐适应稀疏写入的决策。  

## 3 实验  

### 3.1 语言建模评估  
我们在多个参数规模的模型上评估SP-KV:8.1B、2.91B和1.34B,每个模型在C4(Raffel等人,2020)的900B令牌上以32k上下文长度训练。作为参考,我们使用相同的设置训练全注意力基线模型(如8.1B*表示1T令牌)。对于8.1B和2.91B模型,SP-KV检查点通过从已训练的全注意力基线模型继续预训练120B令牌获得(占总预算的1/8,速度为8倍);1.34B模型从头开始用SP-KV训练。除非另有说明,所有SP-KV模型均使用τ=0.5的阈值。  

#### 验证损失缩放定律。  
我们采用Chinchilla缩放定律公式(Hoffmann等人,2022),并针对香草全注意力和SP-KV在8个不同计算预算(1×10^18至2×10^21 FLOPs)上的2.91B模型拟合参数。通过R² > 0.999验证拟合质量(见图2)。全注意力和自剪枝KV注意力显示出相似的计算缩放行为。对更大计算预算的外推预测了两个变体几乎相同的性能,并得到了8.1B模型(未用于拟合)的确认。这表明通过继续预训练(占总预算的1/8)适配自剪枝KV不会降低性能,同时实现了稀疏性的好处。通常,8.1B模型在验证数据上仅保留29.6%的键(τ=0.5)。  

### 3.2 下游任务结果  

**表1:** 8.1B参数模型在标准下游任务和完整RULER长上下文基准(13个子任务类型)上的结果,该模型在32k上下文中使用全注意力训练,与其自剪枝KV变体(τ=0.5)对比。SP-KV在实现约66% KV稀疏化的同时,保持了标准基准性能(平均-0.2%)。整体RULER退化率为-1.2%,完整逐任务细分见**表5**。密度指局部窗口之外保留的KV条目占比。  

除困惑度外,我们还在一个多样化的基准集合上测试了相对香草注意力的非退化性。  

#### 预训练基准套件。  
使用在32k上下文训练的8.1B模型,我们在表1报告的广泛标准下游基准上评估自剪枝KV注意力。总体而言,自剪枝KV变体与全注意力基线紧密匹配,平均变化仅为-0.2%,同时平均仅保留33.7%的非局部KV条目。  

#### 长上下文评估。  
如表1所示,自剪枝KV在长达16k令牌的RULER基准任务上保持了接近基线的性能。在32k处略大的下降(-3.9%)可能反映了在该长度上的有限暴露,因为32k是训练期间看到的最大上下文长度。虽然这些结果展示了SP-KV对长上下文序列的域外泛化能力,但在训练混合中添加RULER风格数据的额外实验(表6)表明,SP-KV从长上下文训练中显著受益,在大多数RULER任务上与训练在相同数据上的香草注意力变体匹配或超越。完整RULER结果见表5。  

#### 稀疏性。  
保留的KV密度在不同任务间变化很大,如表1所示。标准下游任务通常保留20%到50%的非局部KV条目,在非常短的任务(如ARC、OBQA和Winogrande)上密度较高,在生成任务(如GSM8k、HumanEval Plus和MBPP)上密度较低。相比之下,RULER评估表现出低得多的密度,平均约为17%-19%,而“大海捞针”(NIAH)仅需要约5%-7%的保留KV条目,同时保持完美的检索准确率。这支持了先前的发现(Liu 2023 Lost in the Middle: How Language Models Use Long Contexts),即任务相关信息在长上下文输入中是稀疏的,使模型能够丢弃局部窗口之外的大多数过去KV条目。  

见图注  
**图3:**(左)不同τ值(门控决策二值化的阈值)下,NLL与KV缓存密度之间的关系。2.91B模型实验。全注意力+SP-KV帕累托优于混合3:1配置(FairCodeGenTeam 2025 CWM),在约26%密度下实现近乎无损的NLL(+0.07%)(τ=0.5)。(右)自剪枝KV注意力的自定义内核实现(批量大小16)的逐令牌解码延迟(毫秒)。内存瓶颈使得限制键读取的门控替代方案能够优于标准注意力。较低的密度比可以直接转化为较低的延迟。  

### 3.3 控制稀疏性与性能的权衡  
激进地移除大量KV条目会减少内存占用和注意力成本,但通常在中等等稀疏度水平之上会导致严重的质量退化(Jégou 2026 KVZap)。相比之下,我们的方法通过自剪枝KV注意力训练**自诱导**稀疏性,在实践中产生了显著更平坦的退化曲线。  

#### 阈值优化。  
最直接的稀疏性控制是剪枝阈值τ,它提供了保留KV密度与下游性能之间的平滑插值,如图3(左)所示。在SP-KV模型(红色)的操作范围内,模型可以实现高稀疏性,例如约10%的保留密度,同时保持强劲的性能;或者,使用更密集的设置可以恢复接近全注意力的性能。  

#### 混合体。  
SP-KV也可以仅应用于在局部和全局注意力之间交替的Transformer的全局层(蓝色曲线)(Rivière 2024 Gemma 2)。这种架构在设计中已经稀疏,因为局部层不维护长距离KV缓存,但在一定程度内,其全局层仍可通过轻微的性能退化进一步稀疏化。将SP-KV应用于所有层会产生更强的稀疏性-性能前沿,如图3所示,但混合局部/全局变体在训练和推理期间通常更容易针对速度和内存进行优化。其他稀疏性因素在子节C.2中研究。训练期间,这些因素包括学习率调度、效用预测器架构以及辅助损失的使用。推理时,可以通过事后更改阈值来动态调整稀疏性。  

### 3.4 推理效率  
自剪枝KV注意力减少了自回归解码期间存储和读取的键值条目数量。这直接降低了KV缓存内存使用:在保留密度ρ下,非局部缓存占用大约为全注意力的ρ倍。因此,在相同内存预算下,SP-KV可以支持更大的批量或更长的上下文。我们在图3(右)中评估了一个初始的稀疏解码内核。结果显示,在内存受限的体制中,尤其是批量长上下文解码,有明显的收益。在批量大小16下,SP-KV内核始终比全注意力更快,加速比大约为2.1倍至4.6倍。在实践中,当密度接近100%且序列长度较短时,收益会缩小,因为SP-KV的开销抵消了KV读取的减少。虽然还可以进行进一步优化,但这些结果表明SP-KV的稀疏性通过减少缓存占用和解码期间的内存流量,转化为实际的效率提升。  

## 4 设计更强的混合体:SP-KV用于神经架构搜索  

#### 使用SP-KV作为架构探针。  

除了作为一种高效的注意力机制

相似文章

SparDA:用于高效长上下文 LLM 推理的稀疏解耦注意力

arXiv cs.CL

SparDA 提出了一种解耦稀疏注意力架构,通过添加轻量级"Forecast"投影来预测未来的 KV 缓存需求,从而实现从 CPU 到 GPU 的预取(lookahead prefetching),并降低选择开销。在基于稀疏预训练的 8B 模型上,其 prefill 速度最高可提升 1.25×,decode 速度最高可提升 1.7×,相比非 offload 基线,decode 吞吐量最高可提升 5.3×。