@omarsar0: Nous Research 提出的一个很酷的想法。如果你可以使用一个次二次方复杂度的包装器来加速长上下文预训练,并在部署前移除它,会怎样?
摘要
Nous Research 推出了 Lighthouse Attention,这是一种仅用于训练的次二次方包装器,旨在加速扩展点积注意力(SDPA)的长上下文预训练。该包装器可在部署前移除,从而保持原生推理效率。
查看缓存全文
缓存时间: 2026/05/13 10:19
Nous Research 提出的一个很酷的想法。如果可以用一种次二次复杂度的包装器来加速长上下文预训练,并在部署前将其移除,会怎样?这正是 Lighthouse Attention 背后的理念。该方法在普通的缩放点积注意力(SDPA)外层包裹了一个分层的、无梯度的选择层,对称地压缩和解压缩查询(queries)、键(keys)和值(values),同时保留从左到右的因果性。关键在于,它可以在训练末尾的一个短暂恢复阶段中移除,因此部署后的模型仍然运行标准的注意力机制,在推理时没有任何架构成本。初步的大语言模型(LLM)实验报告称,与全注意力基线相比,该方法实现了更快的总训练时间和更低的最终损失。为什么这很重要?大多数高效注意力工作要么改变了部署时的架构,要么为此付出了质量代价。而一个仅在训练期使用、且能干净恢复的包装器则避开了这两点。如果它能扩展,这将成为长上下文预训练的重要训练时加速。论文:https://arxiv.org/abs/2605.06554 在我们的学院学习构建高效的 AI Agent:https://academy.dair.ai
基于灯塔注意力(Lighthouse Attention)的长上下文预训练
来源:https://arxiv.org/html/2605.06554 Bowen Peng* Nous Research [email protected] &Subho Ghosh* Nous Research [email protected] &Jeffrey Quesnelle Nous Research [email protected]
(2026年2月)
摘要
在极长序列长度下训练因果 Transformer 受到缩放点积注意力(SDPA)二次方时间和内存复杂度的瓶颈限制。在这项工作中,我们提出了灯塔注意力(Lighthouse Attention),这是一种仅用于训练的、基于对称选择的分层注意力算法,它包裹在普通 SDPA 之外,并可在训练末尾轻松移除。我们的分层选择也是无梯度的,这使得我们无需处理复杂且可能低效的后向传播内核。我们的贡献有三点:(i) 一个次二次分层的预处理和后处理步骤,对序列进行自适应压缩和解压缩。(ii) 一种对称压缩策略,同时池化查询、键和值,同时保留从左到右的因果性,从而大大提高并行度。(iii) 一种两阶段训练方法,我们在大部分时间使用灯塔注意力进行预训练,并在最后通过短暂训练恢复出全注意力模型。我们运行了初步的小规模 LLM 预训练实验,与所有其他设置匹配的全注意力训练相比,展示了我们方法的有效性,在恢复阶段后实现了更快的总训练时间和更低的最终损失。
完整代码可在以下地址获取:https://github.com/ighoshsubho/lighthouse-attention.
11footnotetext:同等贡献.## 1引言
语言建模的前沿已转向 128K、1M 及更长的上下文,这是由代理式多步推理、长文档理解以及交错的多模态输入推动的[25 (https://arxiv.org/html/2605.06554#bib.bib14),1 (https://arxiv.org/html/2605.06554#bib.bib15),11 (https://arxiv.org/html/2605.06554#bib.bib16),22 (https://arxiv.org/html/2605.06554#bib.bib17),27 (https://arxiv.org/html/2605.06554#bib.bib18),8 (https://arxiv.org/html/2605.06554#bib.bib19),23 (https://arxiv.org/html/2605.06554#bib.bib20)]。在此规模下的训练是主要的硬件瓶颈:缩放点积注意力具有 (\Theta(N^2)) 的算力和内存需求,这是一堵 FlashAttention[29 (https://arxiv.org/html/2605.06554#bib.bib3)] 推后但并未消除的墙。
越来越多的工作用选择机制替代稠密注意力:每个查询仅关注一小部分键。块级方法如 MoBA[20 (https://arxiv.org/html/2605.06554#bib.bib27)] 和 Native Sparse Attention[36 (https://arxiv.org/html/2605.06554#bib.bib28)] 选择连续块,而令牌级方法如 DeepSeek Sparse Attention (DSA;9 (https://arxiv.org/html/2605.06554#bib.bib33)) 通过学习索引器对每个过去令牌进行评分,并将前 (k) 个送入稀疏注意力算子;HISA[40 (https://arxiv.org/html/2605.06554#bib.bib34)] 添加了分层索引器以防止评分成为新瓶颈。这些方法产生了显著的推理加速,但继承了两个不利于长上下文预训练的设计决策:(i) 非对称性:查询保持全分辨率,而键和值被池化,因此层级结构仅作为压缩的可寻址内存,而非多尺度表示。(ii) 架构纠缠:选择逻辑位于注意力内核内部,因此现代张量核心 GPU 加速的精心优化的稠密注意力内核无法重用;每种稀疏方法都自带其内核。
还有一个特定于训练的问题。推理时的稀疏方法[40 (https://arxiv.org/html/2605.06554#bib.bib34),28 (https://arxiv.org/html/2605.06554#bib.bib21),31 (https://arxiv.org/html/2605.06554#bib.bib22),38 (https://arxiv.org/html/2605.06554#bib.bib23),32 (https://arxiv.org/html/2605.06554#bib.bib24)] 由于其稀疏替换仅针对稠密前向传播进行评估,因此在构造上与其稠密骨干一样好。而训练时的稀疏方法必须经受更严格的测试:一旦训练完成,生成的模型是否仍然是一个合格的稠密注意力模型?
我们将这最后一个问题作为我们的核心正确性标准。我们引入了灯塔注意力:一种基于选择的分层注意力,它在多级金字塔上对称地池化 (Q,K,V),使用无参数评分器对每个金字塔条目进行双向评分,并使用融合的块状位onic内核选择前 (K) 个条目。选定的条目形成一个稠密的、因果一致的子序列,使用标准 FlashAttention 进行关注;输出通过确定性内核散射回原位置。前 (K) 步是不可微的,没有直通估计器:梯度通过散射、FlashAttention 和收集流向 (W_Q,W_K,W_V),这些层学习产生在被选中时有用的值,而不是擅长选择的评分。不添加任何辅助参数或损失。由此产生两个后果:对称金字塔是一个完整的多尺度表示,而非压缩上下文;并且由于选择位于注意力路径之外,昂贵步骤是在大小为 (O(LpK+N/p^{L-1})) 的子序列上运行标准 FlashAttention,这在 (L=\log_p(N/K)) 时简化为 (O(N\log N))。
我们的核心实证发现直接解决了训练正确性问题:在短暂的稠密 SDPA 恢复后,灯塔训练模型在相同令牌预算下从 scratch 训练的全稠密 SDPA 基线相当或更优。分层训练信号并没有掏空模型在推理时使用全注意力的能力,这是仅推理的稀疏方法无法宣称的,因为它们从未触及训练循环。我们总结我们的贡献:
- •一种专为长上下文预训练设计的基于选择的分层注意力,具有对称 (Q/K/V) 池化、双向前 (K) 选择以及在收集到的子序列上使用标准 FlashAttention,将稀疏逻辑完全保留在注意力内核之外。
- •融合的 GPU 内核(块状位onic 前 (K) 和自定义散射回写),使该设计在极大上下文下保持快速。
- •据我们所知,这是训练时分层方法最强的实证标准:灯塔预训练后的稠密 SDPA 恢复在训练损失上与从 scratch 训练的稠密基线匹配。
2相关工作
压缩与剪枝。
应对二次方注意力的第一种响应是放弃 softmax,转而使用有界大小的状态:线性注意力 \katharopoulos2020transformers,[4 (https://arxiv.org/html/2605.06554#bib.bib5)],状态空间和门控变体[12 (https://arxiv.org/html/2605.06554#bib.bib6),6 (https://arxiv.org/html/2605.06554#bib.bib7),34 (https://arxiv.org/html/2605.06554#bib.bib8),30 (https://arxiv.org/html/2605.06554#bib.bib9)],以及对数线性注意力[13 (https://arxiv.org/html/2605.06554#bib.bib10)]:这给出了强大的渐近性能,但压缩了整个过去并限制了长程召回[2 (https://arxiv.org/html/2605.06554#bib.bib11)]。第二种保留 softmax 并在块粒度上剪枝,要么是无训练的(MInference, FlexPrefill, XAttention, SpargeAttention[15 (https://arxiv.org/html/2605.06554#bib.bib12),16 (https://arxiv.org/html/2605.06554#bib.bib13),33 (https://arxiv.org/html/2605.06554#bib.bib25),37 (https://arxiv.org/html/2605.06554#bib.bib26)]),要么是端到端的(MoBA, NSA[20 (https://arxiv.org/html/2605.06554#bib.bib27),36 (https://arxiv.org/html/2605.06554#bib.bib28)]);这些可以很好地映射到平铺矩阵乘法,但强制每个块做出单一的保留/丢弃决定,且仅池化键-值侧。第三种在令牌粒度上剪枝,主要在推理时用于 KV 缓存淘汰(H2O, TOVA, SnapKV, LazyLLM, Quest, SparQ[39 (https://arxiv.org/html/2605.06554#bib.bib29),26 (https://arxiv.org/html/2605.06554#bib.bib30),17 (https://arxiv.org/html/2605.06554#bib.bib31),10 (https://arxiv.org/html/2605.06554#bib.bib32),31 (https://arxiv.org/html/2605.06554#bib.bib22),28 (https://arxiv.org/html/2605.06554#bib.bib21)]),或通过端到端训练的学习索引器(DSA[9 (https://arxiv.org/html/2605.06554#bib.bib33)])。这个家族的定义属性是,一旦确定选择,它就作为自定义稀疏矩阵乘法或每查询收集器焊接到注意力算子中,从而排除了重用标准稠密内核的可能性。
层级结构与训练时正确性。
多分辨率注意力[35 (https://arxiv.org/html/2605.06554#bib.bib38)] 以两种形式回归到稀疏 LLM 注意力中。NSA[36 (https://arxiv.org/html/2605.06554#bib.bib28)]、InfLLM-V2[41 (https://arxiv.org/html/2605.06554#bib.bib35)]、Twilight[18 (https://arxiv.org/html/2605.06554#bib.bib36)] 和 DoubleP[24 (https://arxiv.org/html/2605.06554#bib.bib37)] 构建了注意力本身从压缩分支、质心摘要或量化代理读取的层级结构。HISA[40 (https://arxiv.org/html/2605.06554#bib.bib34)] 是 DSA 索引器的无训练、即插即用替代品,它运行从块到令牌的两阶段评分,并将选定的令牌原样转发到 DSA 已使用的同一稀疏 MLA 算子。在所有情况下,层级结构仅适用于键和值,且产生的选择仍然馈送给自定义稀疏注意力内核。灯塔注意力在三个轴上有所不同:它将查询与键和值对称地池化为相干的多分辨率 (Q^{(\ell)},K^{(\ell)},V^{(\ell)}) 三元组;金字塔仅用于排序和选择,因此随后的注意力是在具有稠密子序列的标准 FlashAttention 上运行,内核内没有稀疏索引;并且它通过不可微的前 (k) 选择(由可微散射包裹)进行端到端训练,没有辅助损失或直通估计器。仅推理的稀疏方法(包括 HISA)从其底层稠密模型继承正确性下限,但训练时的稀疏方法(MoBA, NSA)必须回答它们产生的权重是否仍然是合格的稠密模型。我们将短暂的稠密 SDPA 恢复并恢复出从 scratch 训练的稠密基线质量作为我们的核心正确性标准。
Hierarchical Selector(H_t)Projections(W_Q,W_K,W_V)Pyramid PoolDense GatherSDPABack-scatter(O_t)(\tilde{Q},\tilde{K},\tilde{V})(\tilde{O})(\mathcal{I})
图 1:灯塔注意力架构。前向(黑色):(H_t) 被投影到 (Q,K,V),通过主干上的对称 Pyramid Pool,并受来自 Hierarchical Selector 的索引 (\mathcal{I}) 引导,馈送给稠密收集器,该收集器将收集的层级结构拓扑排序为单个连续且因果的序列,然后进行标准 SDPA,并散射回以产生 (O_t)。选择(绿色):选择器从主干上获取池化摘要;无参数评分器根据 (l_2) 范数对它们进行排名,前 (K) 内核保留最大条目,发出整数索引 (\mathcal{I}) 合并回主干的稠密收集器。梯度(红色,虚线):(\nabla L) 沿主干传播((O_t \rightarrow) scatter (\rightarrow) FA (\rightarrow) gather (\rightarrow) pyramid pool (\rightarrow W_{Q,K,V} \rightarrow H_t));选择器分支不可微并被绕过。
3方法
我们介绍灯塔注意力,一种用于长上下文预训练的基于选择的分层注意力机制。灯塔注意力用一个四阶段流水线替换标准 Transformer 注意力层,该流水线包围但不修改注意力内核:前注意力选择阶段驱动连续收集,标准 FlashAttention[7 (https://arxiv.org/html/2605.06554#bib.bib40)] 在收集到的子序列上运行,后注意力散射将结果写回原始位置。选择由该层自身查询、键和值的多分辨率金字塔上的无参数评分功能驱动,因此灯塔注意力在底层注意力块之外没有引入新的可学习参数。
3.1预备知识
设 (X \in \mathbb{R}^{N \times d_{\text{model}}}) 为输入,(W_Q,W_K,W_V \in \mathbb{R}^{d_{\text{model}} \times d}) 为一个头的投影矩阵,(M \in \mathbb{R}^{N \times N}) 为因果掩码。标准缩放点积注意力[5 (https://arxiv.org/html/2605.06554#bib.bib39)]为
(Q=XW_Q,\quad K=XW_K,\quad V=XW_V,\qquad \mathrm{Attn}(Q,K,V)=\mathrm{softmax}!\left(\frac{QK^{\top}}{\sqrt{d}}+M\right)V,)(1)
其中时间和内存成本均为 (\Theta(N^2d))。FlashAttention 降低了常数但不改变渐近复杂度;当 (N \geq 10^5) 时,该项占主导地位。灯塔注意力用以下方法替换公式 (1 (https://arxiv.org/html/2605.06554#S3.E1)):(i) 将 (Q,K,V) 对称平均池化为 (L) 级金字塔(因子 (p));(ii) 在所有级别上联合进行无参数评分和融合的块状位onic 前 (k) 选择;(iii) 在 (S \ll N) 个选定条目的连续子序列上运行标准 FlashAttention;(iv) 散射回写,将每个输出分布到它代表的 (p^\ell) 个基础位置。阶段 (ii) 和 (iv) 是自定义内核(第5节 (https://arxiv.org/html/2605.06554#S5));阶段 (iii) 是与稠密基线相同的 FlashAttention 调用。前 (k) 被视为离散且不可微:索引不携带梯度,评分功能也不进行训练。梯度仅通过阶段 (iv)、(iii) 和收集到达 (W_Q,W_K,W_V):投影层学习产生被选中时有用的值,而不是擅长选择的评分,从而避开了可学习选择器的优化脆弱性。
Hierarchical Selector1. Pyramid Pooll=2(\ell=2)l=1(\ell=1)l=0(\ell=0)mean-pool by(p^\ell)2. Norm Scorel=0(\ell=0)l=1(\ell=1)l=2(\ell=2)l2范数 + max-pool3. Chunked Bitonic Top-Kl=2(\ell=2)l=1(\ell=1)l=0(\ell=0)parentkeptbitonic argsort + tree-prune(Q,K,V)[B,S,H,D][B,S,H,D](\mathcal{I})[B,H,K][B,H,K]pyramidtokensscoresqkscoreskq
图 2:Pyramid Pool 和 Hierarchical Selector。Pyramid Pool 是一个固定的预选阶段,位于选择器之外。((1)) Pyramid Pool 以 (p^\ell) 对 (Q,K,V) 进行平均池化;线条显示哪些令牌馈送给每个摘要。池化后的令牌进入选择
相似文章
提示缓存,但用于 RL 训练——在长提示/短回复负载上实现 7.5 倍加速
一种面向开源 RL 训练引擎的全新优化技术在训练过程中引入了提示缓存,通过减少冗余计算,在长提示、短回复负载场景下实现了高达 7.5 倍的加速。
Ulysses 序列并行:百万Token上下文训练
Ulysses 序列并行是一种用于训练具有百万Token上下文的大语言模型的技术,通过将序列块分布在多个GPU上来降低内存需求,实现高效的长上下文训练。它与HuggingFace Accelerate、Transformers Trainer和TRL集成,支持Flash Attention和DeepSpeed ZeRO。
@songhan_mit: 探索简化 OPD 以高效进行 LLM 后训练:
本文介绍了一种简化 OPD 以实现大语言模型高效后训练的方法。
浅层预填,深层解码:通过层非对称 KV 可见性实现高效的长上下文推理
本文介绍了 SPEED,一种层非对称 KV 可见性策略,通过仅在预填阶段的下层处理提示 token,同时在解码阶段保持全深度注意力,从而降低长上下文推理的成本。
LongAct:利用内在激活模式进行长上下文强化学习
LongAct 提出了一种显著性引导的稀疏更新策略,通过选择性更新与查询和键向量中高幅值激活相关的权重来改进 LLMs 的长上下文推理能力,在 LongBench v2 上实现了约 8% 的提升。