如何在强化学习后训练中压缩 KV 缓存?用于内存高效对齐的阴影掩码蒸馏

arXiv cs.LG 论文

摘要

本文提出了阴影掩码蒸馏(SMD),旨在解决大语言模型在强化学习后训练中因 KV 缓存压缩而导致的离策略偏差。该方法引入了一种机制,确保在策略上的对齐,并提高长上下文推理任务的内存效率。

arXiv:2605.06850v1 公告类型:新论文 摘要:强化学习(RL)已成为释放大型语言模型(LLM)高级推理能力的关键范式,涵盖了如 RLHF 和 RLAIF 等框架。无论采用何种具体的优化算法(例如 PPO、GRPO 或 Online DPO),在线 RL 本质上都需要一个探索性的轨迹生成(rollout)阶段。然而,对于长上下文推理任务,这一 rollout 阶段由于巨大的 Key-Value (KV) 缓存占用量,引发了严重的“内存墙”问题。虽然在 rollout 期间应用 KV 缓存压缩可以缓解这种内存开销,但它会引发关键的离策略偏差。尽管现代 KV 压缩在标准推理过程中通常是几乎无损的,但即使微小的近似误差也会因 RL 优化固有的不稳定性而被大幅放大。具体而言,采样器在稀疏上下文中生成响应,而学习者则使用完整、密集的上下文更新参数。现有的统计解决方案(如重要性重加权)难以纠正这种放大的偏差,且存在高梯度方差和严重的样本效率低下问题。
查看原文 导出为 Word 导出为 PDF
查看缓存全文

缓存时间: 2026/05/11 06:57

# 如何在 RL 后训练中压缩 KV Cache?基于 Shadow Mask 蒸馏的记忆高效对齐

来源: https://arxiv.org/html/2605.06850
Rui Zhu1 Weiheng Bai2 Qiushi Wu2 Yang Ren1 Haixu Tang3 Yuchu Liu1

1耶鲁大学 2明尼苏达大学双城分校 3印第安纳大学布卢明顿分校

###### 摘要

强化学习(RL)已成为释放大型语言模型(LLMs)高级推理能力的关键范式,涵盖了 RLHF 和 RLAIF 等框架。无论采用何种特定的优化算法(例如 PPO、GRPO 或 Online DPO),在线 RL 本质上都需要一个探索性的轨迹生成(rollout)阶段。然而,对于长上下文推理任务,由于 Key-Value(KV)缓存占用了巨大的内存空间,这一 rollout 阶段带来了严峻的“内存墙”问题。虽然在 rollout 过程中应用 KV 缓存压缩可以缓解这一内存开销,但它会引发严重的离策略(off-policy)偏差。尽管现代 KV 压缩技术在标准推理中通常是近乎无损的,但即使是微小的近似误差,也会因 RL 优化固有的不稳定性而被急剧放大。具体而言,采样器在稀疏上下文下生成响应,而学习者则使用完整、密集的上下文来更新参数。现有的统计解决方案,如重要性重加权,难以纠正这种放大的偏差,往往导致梯度方差高且样本效率极低。

在本文中,我们提出了 **Shadow Mask Distillation (SMD)**,这是一种优雅的结构化框架,旨在消除这种结构性失配。SMD 不再依赖事后统计修补,而是将在稀疏 rollout 期间记录的“Shadow Mask(阴影掩码)”直接注入到学习者的注意力层中,从数学上保证了完美的在策略(on-policy)对齐。此外,我们引入了一种双轨 KL 蒸馏机制,以将全局上下文知识从密集策略转移到掩码策略中。在 4B 模型上的广泛实验验证了 SMD 的显著功效。在 50% 的 KV 缓存压缩率下,SMD 实现了**近乎无损的压缩**,与未压缩的基线模型保持高度的竞争力(例如,在 GSM8K 上达到 73.6% vs. 74.5%)。此外,它防止了 SOTA 拒绝采样基线中固有的严重长上下文性能退化,其掩码模拟完全消除了原生 VRAM 峰值,为长上下文 RL 树立了稳健、记忆高效的标准。

## 1 引言

大型语言模型(LLMs)(Brownet al., 2020 (https://arxiv.org/html/2605.06850#bib.bib33); Touvronet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib10); Baiet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib12); Xionget al., 2025 (https://arxiv.org/html/2605.06850#bib.bib72)) 取得了前所未有的成功,这主要得益于诸如基于人类反馈的强化学习(RLHF)(Ouyanget al., 2022 (https://arxiv.org/html/2605.06850#bib.bib1); Baiet al., 2022 (https://arxiv.org/html/2605.06850#bib.bib39); Liet al., 2025 (https://arxiv.org/html/2605.06850#bib.bib69)) 等对齐技术。最近的进展,特别是组相对策略优化(GRPO)(Shaoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib2)) 和在线直接偏好优化(Online DPO)(Rafailovet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib20)),进一步简化了这一过程。然而,随着 LLM 越来越多地部署在长上下文应用中 (Ding and others, 2024 (https://arxiv.org/html/2605.06850#bib.bib60)),模型自回归生成多个轨迹的 rollout 阶段遭遇了 formidable “内存墙”。存储长序列的 Key-Value (KV) 缓存 (Vaswaniet al., 2017 (https://arxiv.org/html/2605.06850#bib.bib22); Kwonet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib30); Shazeer, 2019 (https://arxiv.org/html/2605.06850#bib.bib56)) 所需的巨大内存 footprint 严重限制了批量大小和训练吞吐量,使得在标准硬件上进行长上下文 RLHF 计算成本高昂。

一种自然的对策是在 rollout 阶段应用 KV 缓存压缩算法,例如 SnapKV (Liet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib3)) 或 H2O (Zhanget al., 2023 (https://arxiv.org/html/2605.06850#bib.bib5)),以稀疏化上下文。虽然这缓解了内存瓶颈,但它无意中在 RL 流水线中引入了结构性的**二分法(dichotomy)**。在生成期间,actor 表现为受限于稀疏上下文的“短视”策略 ($\pi_{\text{sparse}}$)。然而,在优化阶段,学习者使用未压缩的密集上下文 ($\pi_{\text{dense}}$) 来评估这些轨迹。这种不对称性严重违反了策略梯度方法的核心假设 (Sutton and Barto, 2018 (https://arxiv.org/html/2605.06850#bib.bib7); Wanget al., 2024 (https://arxiv.org/html/2605.06850#bib.bib73)),造成了巨大的离策略偏差,误导了梯度更新,并经常导致不可逆的策略崩溃 (Chenet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib71))。

当前的文献试图通过统计干预来解决这种分歧。最新的 state-of-the-art 方法依赖于稀疏感知拒绝采样和基于重要性的重加权 (Luoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib8))。不幸的是,这些事后统计补丁治标不治本。拒绝采样丢弃昂贵的 rollout 轨迹,导致极低的样本效率;与此同时,重要性重加权引入了极高的梯度方差,严重 destabilize 训练动态。

为了打破这一僵局,我们提出了 **Shadow Mask Distillation (SMD)**,这是一种结构性而非统计性的解决方案。我们观察到,如果学习者经历与生成器完全相同的信息瓶颈,物理上的在策略对齐就可以恢复。SMD 通过在稀疏 rollout 期间记录二进制“Shadow Mask”并将其物理注入到学习者的因果注意力矩阵中来实现这一点。这种时间冻结机制保证了严格的 $\pi_{\text{sparse}} \equiv \pi_{\text{shadow}}$ 对齐,完全消除了离策略方差。为了防止模型过拟合于截断的上下文,我们同时执行密集前向传递,应用 Kullback-Leibler (KL) (Kullback and Leibler, 1951 (https://arxiv.org/html/2605.06850#bib.bib38)) 散度惩罚,将全局上下文推理蒸馏到掩码策略中。此外,我们揭示了一个关键的工程见解:高级框架中用于 KV 驱逐的原生张量切片会引发“非原地分配峰值(Not-In-Place Allocation Spike)”。SMD 完全绕过了这个问题,提供了一种与框架无关的解决方案,避免了灾难性的内存溢出 (OOM) 峰值,而无需进行低级别的 C++/CUDA 修改 (Zhenget al., 2023 (https://arxiv.org/html/2605.06850#bib.bib29); Kwonet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib30))。

我们的主要贡献总结如下:

- • 我们识别了记忆高效 RLHF 中的结构性离策略失配,并提出了 **Shadow Mask Distillation**,这是一个优雅的双轨框架,保证完美的梯度对齐且零数据浪费。
- • 我们通过实证证明,SMD 完全消除了记忆受限的学习者优化阶段中物理 KV 驱逐固有的巨大瞬时内存碎片化峰值,为长上下文生成提供了完美稳健的执行环境。
- • 我们揭示了注意力稀疏化 rollouts 的隐式正则化效应。在 Reddit TL;DR 基准测试中,SMD 在 ROUGE-L 上优于密集基线(相对 +0.6%),并在收敛速度和训练稳定性方面显著超越统计重加权基线。

## 2 相关工作

#### 记忆高效的 RLHF 和 GRPO。

通过 RL 对齐 LLM,例如近端策略优化(PPO)(Schulmanet al., 2017 (https://arxiv.org/html/2605.06850#bib.bib4)),通常需要维护多个模型副本(Actor, Critic, Reference, Reward),从而产生巨大的内存开销。GRPO (Shaoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib2)) 通过消除 Critic 模型并利用组相对优势来缓解这一问题。尽管如此,自回归 rollout 阶段仍然受到 KV 缓存分配的严重瓶颈限制 (Kwonet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib30)),特别是当序列长度扩展到数万 token 时。我们的工作直接针对这一 rollout 内存墙,提出了一种正交优化方法,可以无缝集成到现代 RLHF 流水线中 (Ouyanget al., 2022 (https://arxiv.org/html/2605.06850#bib.bib1); Stiennonet al., 2020 (https://arxiv.org/html/2605.06850#bib.bib17))。

#### LLM 中的 KV 缓存压缩。

为了解决 Transformer 的线性内存扩展问题,提出了各种 KV 缓存压缩策略。基于驱逐的方法,如 H2O (Zhanget al., 2023 (https://arxiv.org/html/2605.06850#bib.bib5)) 和 SnapKV (Liet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib3)),根据累积的注意力分数选择性保留“重击手”token。无需调整的量化框架如 KIVI (Liuet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib24)) 进一步推动了压缩极限。像 StreamingLLM (Xiaoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib6)) 这样的替代方法利用注意力 sink 现象或流式机制,在无限长度设置下保持稳定的生成,而最近的工作如 FOCUS (Zhuet al., 2025 (https://arxiv.org/html/2605.06850#bib.bib75)) 将近无损压缩能力扩展到超长 DNA 序列等专门领域。虽然这些方法在标准推理期间非常有效,但将它们原生集成到分布式 RL 训练循环中(例如,Megatron-LM (Shoeybiet al., 2019 (https://arxiv.org/html/2605.06850#bib.bib32)) 或 Ray (Moritzet al., 2018 (https://arxiv.org/html/2605.06850#bib.bib31); Fanet al., 2025 (https://arxiv.org/html/2605.06850#bib.bib74)))会引发深刻的离策略偏差 (Zhuet al., 2023 (https://arxiv.org/html/2605.06850#bib.bib70))。我们的 Shadow Mask 框架允许任意基于驱逐的压缩算法(例如 SnapKV 或随机保留)安全地集成到 RLHF 循环中,而不会破坏策略梯度。

#### RL 中朴素 KV 压缩的失败。

纠正离策略偏差是 RL 中的经典问题。标准方法依赖于重要性采样(IS)(Sutton and Barto, 2018 (https://arxiv.org/html/2605.06850#bib.bib7)) 来重新加权梯度,这通常通过裁剪机制来稳定,正如 PPO (Schulmanet al., 2017 (https://arxiv.org/html/2605.06850#bib.bib4)) 中著名实现的那样。然而,虽然在标准推理期间直接应用 KV 压缩无缝工作,但将其原生插入到 RL 训练循环中以辅助 rollout 生成则会灾难性地失败。其原因源于 actor 的稀疏生成与学习者的密集评估之间的严重结构性失配。为了说明这一点,当标准的 PPO/GRPO 流水线在 rollouts 期间天真地增加 50% 的 SnapKV 压缩时,训练动态表现出严重的奖励崩溃,模型在 GSM8K 基准上的准确度从 74.5% 暴跌至无法使用的 64.3%(详见第 4 节 (https://arxiv.org/html/2605.06850#S4))。这 unequivocally 表明,如果没有严格的离策略校正,朴素的 KV 压缩与在线 RL 根本不相容。

#### 开创性努力及其局限性。

迄今为止,Sparse-RL (Luoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib8)) 是该新兴领域中唯一的开创性努力,成功识别了这一关键瓶颈,并为长上下文记忆高效 RL 提出了第一个可行的框架。通过巧妙结合稀疏感知拒绝采样和重要性重加权,Sparse-RL 有效地缓解了密集上下文和稀疏上下文之间的分歧。然而,其统计性质 inherently 限制了其有效性。在 LLM 的语境中,由于动作空间(词汇表)和序列长度巨大,重要性比率呈指数级增长 (Luoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib8)),导致灾难性的梯度方差。因此,Sparse-RL 被迫激进地丢弃大量昂贵的 rollout 轨迹(通常 >20%)以维持稳定性。

受这种巨大的数据浪费和持续的梯度方差启发,我们认为离策略困境不应通过统计手段修补,而应通过结构设计来解决。这一深刻见解直接激励了我们提出的 Shadow Mask Distillation 方法,我们将在下一节中详细说明。

## 3 方法论

在本节中,我们介绍 **Shadow Mask Distillation**,这是一种新颖的结构化框架,旨在消除 RLHF 中由 KV 缓存压缩引起的离策略偏差,而不依赖高方差的统计补丁。我们首先建立对不对称上下文困境的直觉理解,随后形式化我们的双轨机制。图 1 (https://arxiv.org/html/2605.06850#S3.F1) 提供了我们提出的方法的整体直觉。

参见标题 **Figure 1**: Shadow Mask Distillation (SMD) 的整体架构。在阶段 1(Rollout)中,KV 驱逐算法动态丢弃 token 以节省内存,将保留索引记录到二进制 Shadow Mask ($M$) 中。在阶段 2(Learner)中,SMD 执行双轨前向传递:**对齐轨道**应用 Shadow Mask 以完美重建稀疏生成环境,用于严格的 on-policy GRPO 参数更新,而**蒸馏轨道**利用完整的密集上下文,通过 KL 散度隐式正则化短视的稀疏策略。

### 3.1 预备知识与直觉:不对称上下文困境

在标准的 RLHF 算法如 GRPO (Shaoet al., 2024 (https://arxiv.org/html/2605.06850#bib.bib2)) 中,训练流水线由两个解耦的阶段组成:**rollout**(轨迹生成)和 **learner**(参数更新)。对于给定的提示 $x$,模型通过从当前策略 $\pi_{\theta}$ 中采样生成一组输出 $\{y_{1}, y_{2}, \dots, y_{K}\}$,并使用估计的优势 $\hat{A}_{i}$ 更新参数。

为了克服长上下文任务中的“内存墙”,记忆高效系统在 rollout 阶段应用 KV 缓存压缩(例如 SnapKV)。这创建了一个受限的策略,记为 $\pi_{\text{sparse}}$。然而,在学习者阶段,标准框架使用完整的、未压缩的上下文重新计算对数概率,实际上是在全知策略 $\pi_{\text{dense}}$ 下评估轨迹。

**直觉:** 这种差异类似于不公平的评审过程:玩家在蒙眼的情况下导航迷宫(稀疏 rollout),但教练使用完整的全局地图(密集 learner)来评估他们的移动。教练惩罚玩家未能利用他们从未观察到的信息。虽然最近的工作试图通过统计重要性重加权来修补这一点,但这种方法遭受极高的梯度方差和样本效率低下的问题。相反,我们的方法从物理上完全消除了这种偏差:我们只是在评估期间给教练戴上完全相同的眼罩。

### 3.2 轨道 1:通过 Shadow Masking 实现结构化的在策略对齐

为了原生地将学习者与 rollout 生成器对齐,我们引入了 **Shadow

相似文章

自蒸馏作为大语言模型的性能恢复机制:对抗压缩和灾难性遗忘

arXiv cs.CL

本文介绍了自蒸馏微调(SDFT)作为大语言模型性能恢复机制,用于解决灾难性遗忘、量化和剪枝导致的性能下降问题。作者利用中心核对齐(CKA)提供了理论证明,表明自蒸馏能够使学生模型的高维流形与教师模型的最优结构对齐,从而有效恢复丧失的能力。