通过学习的Token路由在Transformer中实现自适应计算深度
摘要
本文提出了Token-Selective Attention (TSA),一种可微的token路由机制,它学习在每个token上跳过Transformer层中不必要的计算,从而在语言建模任务中将token层操作减少14-23%,且质量损失极小。
查看缓存全文
缓存时间: 2026/05/08 06:35
# 基于学习令牌路由的Transformer自适应计算深度
来源:https://arxiv.org/html/2605.05222
Ahmed Abdelmuniem Abdalla Mohammed 独立研究员 ahmed\.abdelmuniem@gmail\.com ORCID: 0009\-0008\-7410\-6621 (https://orcid.org/0009-0008-7410-6621)
###### 摘要
标准Transformer架构对每个令牌应用相同数量的层,无论上下文难度如何。我们提出**令牌选择性注意力(TSA)**,一种学习到的逐令牌门控,作用于连续Transformer块之间的残差更新。每个门控是一个轻量级两层多层感知器(MLP),产生一个连续的暂停概率,使该机制实现端到端可微,参数开销为1.7%,且无需更改基础架构。值得注意的是,TSA在没有任何显式深度压力的情况下学习到难度成比例的路由:即使λ=0(无深度正则化),任务损失梯度单独驱动路由器跳过20%的令牌层操作。在字符级语言建模上,TSA在Tiny-Shakespeare和enwik8上节省了14–23%的令牌层操作(TLOps),质量损失小于0.5%。在匹配效率下,TSA的验证损失比早退出低0.7%,并且学习到的路由直接迁移到推理时的稀疏执行,实现真实的墙钟加速。
关键词:自适应计算、令牌路由、稀疏Transformer、高效推理、深度正则化
## 1 引言
Transformer语言模型(Vaswani et al., 2017 (https://arxiv.org/html/2605.05222#bib.bib6))对每个序列中的每个令牌应用固定数量的层。这种设计以逐令牌适应性换取架构简洁性。实际上,这种权衡代价高昂:在可预测上下文中的常见令牌所需的处理远少于在新颖构造中的稀有令牌,但两者在前向传递的每一层都获得相同的计算量。
这种低效性在推理规模上尤为严重。对于大型部署模型,主要成本是对所有令牌通过所有层的前向传递。如果相当一部分令牌可以在不损失质量的情况下提前退出,节省的计算量将直接转化为更低的延迟和更高的吞吐量。
已有几种方法解决了这个问题。Graves (2016 (https://arxiv.org/html/2605.05222#bib.bib1)) 为循环神经网络(RNN)引入了自适应计算时间(ACT),在循环步骤中累积暂停概率。Dehghani et al. (2019 (https://arxiv.org/html/2605.05222#bib.bib2)) 将这一思想扩展到深度共享的Transformer层,提出了Universal Transformer。最近,Raposo et al. (2024 (https://arxiv.org/html/2605.05222#bib.bib3)) 提出了深度混合(Mixture-of-Depths, MoD),使用硬top-k选择将令牌路由通过一个固定的子集层;Bae et al. (2025 (https://arxiv.org/html/2605.05222#bib.bib4)) 引入了混合递归(Mixture of Recursions),对每个令牌应用学习到的步数递归块;Chen et al. (2025 (https://arxiv.org/html/2605.05222#bib.bib5)) 提出了内部思考Transformer(Inner Thinking Transformer),在高风险位置插入额外的计算步骤。
我们提出**令牌选择性注意力(TSA)**:一种连续软门控,作用于残差更新,基于每个令牌的当前隐藏状态。该机制在架构上最小化——每个块间间隙一个两层MLP——并且完全可微,无需直通估计器、Gumbel采样或强化学习。我们的贡献是:
- • 一种简单、可微的令牌路由机制,对每个层每个令牌软门控残差更新(§2 (https://arxiv.org/html/2605.05222#S2))。
- • 路由从任务损失梯度中出现的证据:在λ=0(无深度正则化)下,路由器仅通过任务损失梯度就学会跳过20%的令牌层操作,而没有任何显式的深度压力(§3.4 (https://arxiv.org/html/2605.05222#S3.SS4))。
- • 跨数据集的字符级语言建模验证:在Tiny-Shakespeare和enwik8上节省14–23%的令牌层操作,质量损失小于0.5%(§3.2 (https://arxiv.org/html/2605.05222#S3.SS2),§3.3 (https://arxiv.org/html/2605.05222#S3.SS3))。
- • 消融实验显示对λ在两个数量级范围内的鲁棒性(§3.4 (https://arxiv.org/html/2605.05222#S3.SS4)),在匹配效率下质量优于早退出(§3.5 (https://arxiv.org/html/2605.05222#S3.SS5)),以及在商品硬件上通过稀疏推理实现真实墙钟加速(§3.6 (https://arxiv.org/html/2605.05222#S3.SS6))。
## 2 方法
### 2.1 架构
设一个预层归一化解码器专用Transformer具有块f0, f1, ..., fL-1,每个块应用多头自注意力和前馈网络(FFN),带有残差连接和LayerNorm(Ba et al., 2016 (https://arxiv.org/html/2605.05222#bib.bib12))。在TSA中,在每个块fl之后插入一个轻量级路由器rl,其中l=0, ..., L-2。
块f0是**主干**,总是无条件执行:裸令牌嵌入不携带上下文信号,在步骤零做路由决策是无信息且可能退化的。路由从主干之后开始:
h ← f0(h), pl = rl(h), h ← fl+1(h, pl), l=0, ..., L-2. (1)
图1 (https://arxiv.org/html/2605.05222#S2.F1) 说明了双模式机制:训练时的软门控(可微)和推理时的硬阈值稀疏执行(真实FLOPs节省)。
参见图注
**图1:**TSA双模式架构。路由器rl读取隐藏状态h并为每个令牌生成暂停概率pl。**左(训练)**:所有令牌始终通过attn + FFN;残差更新由(1-pl)软缩放,保持门控可微以使路由器学习。**右(推理)**:注意力保持密集,但pl>0.5的令牌完全跳过FFN,通过gather/scatter实现,从而节省真实FLOPs。主干块f0始终无条件执行。
### 2.2 路由器架构
每个路由器是一个具有sigmoid输出的两层MLP:
rl(h) = σ(Wl(2) ReLU(Wl(1) h + bl(1)) + bl(2)), rl(h) ∈ (0,1)B×T, (2)
其中隐藏维度为d/4(最小为16)。每个路由器增加d²/4 + d/2 + 1个参数;在d=256,L=6时,对于4.78M参数的基础模型总计约83K(1.7%开销)。
最后偏置bl(2)初始化为-1.0,初始化时σ(-1)≈0.27。该偏置防止在模型学习到有用表示之前早期崩溃为“全部暂停”。
### 2.3 门控块更新
对于每个路由决策l=0, ..., L-2,块fl+1的门控更新为:
h ← h + (1 - pl) ⊙ Δ^{attn}_{l+1}(h), (3)
h ← h + (1 - pl) ⊙ Δ^{ffn}_{l+1}(h), (4)
其中pl ∈ (0,1)B×T沿模型维度d广播,Δ^{attn}_{l+1}和Δ^{ffn}_{l+1}分别是块fl+1的预归一化注意力和前馈残差增量。当pl=0时,更新与标准Transformer相同。当pl=1时,状态不变——该块被跳过。插值是平滑的,在训练期间保持通过pl的梯度流。
### 2.4 深度正则化
如果没有暂停的激励,路由器默认pl≈0,TSA退化为带额外参数的标准Transformer。我们添加了一个深度正则化项,温和地鼓励早期暂停:
L_depth = λ · (1/(L-1)) Σ_{l=0}^{L-2} (1-pl)的均值, (5)
其中(1-pl)的均值是层l的平均活跃分数(在批次和序列位置上平均)。总训练损失为L = L_task + L_depth。我们在语言实验中使用λ=0.001;§3.4 (https://arxiv.org/html/2605.05222#S3.SS4) 证明了TSA在λ∈[0, 0.1]范围内具有鲁棒性。
### 2.5 计算度量
我们使用**令牌层操作(TLOps)**来衡量计算量:对于每个块,TLOps等于在该块处理的令牌数。跨路由决策的平均活跃分数为:
α = (1/(L-1)) Σ_{l=0}^{L-2} (1-pl)的均值。 (6)
相对于固定深度基线的TLOps节省为:
Δ = 1 - (1 + (L-1)α) / L。 (7)
主干块(始终活跃)包含在分子和分母中,使Δ成为一个保守估计。
***训练计算说明。** 在训练期间,所有层都完全执行:门控缩放残差更新但不跳过计算。因此TLOps衡量的是每层对最终表示的有效贡献,而非实际节省的FLOPs。在推理时,稀疏TSA(§3.6 (https://arxiv.org/html/2605.05222#S3.SS6))利用gather/scatter操作低贡献位置,实现真实计算节省。
## 3 实验
### 3.1 合成算法任务
#### 设置。
我们使用解码器专用Transformer在**复制**和**排序**任务上进行训练,序列长度为10,词汇表大小为32。输入格式为[BOS] src [SEP] tgt [EOS],对源令牌掩码损失。基线模型和TSA均使用d=128,L=6,H=4,d_ff=512(基线:1.20M参数;TSA:1.22M参数,+1.7%)。训练使用AdamW(Loshchilov and Hutter, 2019 (https://arxiv.org/html/2605.05222#bib.bib11)),β=(0.9,0.95),lr=3×10^{-4},λ_wd=0.1,在10K训练序列上进行10K梯度步。我们报告在1K保留序列上的令牌级序列准确率。
#### 结果。
**表1:**合成任务结果(d=128,L=6,玩具词汇表)† TLOps节省 = 1 - (1+(L-1)α)/L,包括强制的骨干块。
路由模式直接反映了任务难度。复制是恒等映射:路由器学会几乎所有令牌在主干块后就已完全确定(α=0.341;总体节省54.9% TLOps)。排序需要比较和排列,得到α=0.730——任务真正需要更多计算的地方。这种难度成比例分配在没有关于任务身份或难度的任何显式监督下出现。
### 3.2 字符级语言建模
#### 设置。
我们在Tiny-Shakespeare(Karpathy, 2015 (https://arxiv.org/html/2605.05222#bib.bib10))上训练(1.1M字符,65字符词汇表,80/10/10训练/验证/测试划分)。两个模型均使用d=256,L=6,H=8,d_ff=1024,上下文长度128。训练使用AdamW,余弦学习率调度,批次大小64,进行5,000梯度步(基线:4.78M参数;TSA:4.86M参数,+1.7%)。令牌嵌入初始化时没有填充索引:字符索引0是换行符(约占语料库的8%),其嵌入梯度不能被归零。
#### 结果。
**表2:**语言建模结果(d=256,L=6,Tiny-Shakespeare)验证损失增加:+0.006 nats(+0.4%相对)。BPC = bits-per-character。
TSA达到α=0.726:节省22.8%的令牌层操作,验证损失成本仅为0.006 nats(+0.4%)。两个模型在相同步数达到所有收敛阈值,表明TSA不妨碍收敛(图2 (https://arxiv.org/html/2605.05222#S3.F2))。TSA曲线在计算轴上始终位于左侧,证实节省在整个训练过程中保持稳定。
参见图注
(a) 验证损失 vs. 训练步数。
(b) 验证损失 vs. 累积TLOps(×10^9)。
**图2:**TSA(红色)与Baseline(蓝色)在Tiny-Shakespeare上的表现。左:等效收敛速度。右:TSA以22.8%更少的令牌层操作达到相同的损失。
### 3.3 跨数据集验证:enwik8
为了测试TSA是否能泛化到单个语料库之外,我们在enwik8(Hutter, 2006 (https://arxiv.org/html/2605.05222#bib.bib14))上训练:英语维基百科的前10^8字节(原始XML,6,064个唯一字符)。该语料库比莎士比亚多样化得多——包含标记、多语言文本、表格和数学符号。我们使用d=256,L=6,H=8,d_ff=1024,上下文长度256,批次大小64,进行5,000步(6.35M参数基线;6.43M TSA)。实验在Apple M1 Pro上使用MLX(Apple Machine Learning Research, 2023 (https://arxiv.org/html/2605.05222#bib.bib16))进行。
**表3:**enwik8结果(d=256,L=6,上下文256)TSA质量 vs. 基线:-0.4%(TSA略好;在噪声范围内)。
TSA在enwik8上达到α=0.833,比莎士比亚的α=0.726更为保守。路由器在结构多样的维基百科语料库上分配了更多计算,同时仍以零质量成本节省了13.9%的TLOps。两个条件在相同步数(500、750、1,000步)达到所有收敛阈值(≤2.5、≤2.0、≤1.8 BPC)。跨数据集结果证实路由机制学习到了内容相关的信号,而不是过拟合到语料库特定模式。训练曲线见图5(附录)。
### 3.4 消融:深度正则化敏感性
我们在Tiny-Shakespeare上对λ∈{0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5}进行了扫描,所有其他超参数固定为§3.2中的值。图3(a)显示了质量-效率帕累托曲线;完整结果见表6(附录)。
参见图注
(a) λ扫描:验证损失 vs. 活跃分数。
(b) 早退出 vs. TSA帕累托曲线。
**图3:**在Tiny-Shakespeare上的消融研究。**左**:TSA在λ∈[0,0.1]范围内具有鲁棒性;质量范围为0.015 nats(1.04%)。即使λ=0也能产生有意义的路由。**右**:TSA(红星)在匹配的α≈0.726下主导了早退出阈值扫描(蓝色)。发现了三个结果。*(i) λ=0仍然路由:*没有任何显式的深度压力,路由器仅通过任务损失梯度就学会节省20.4%的TLOps(α=0.755)。这是核心发现:门控乘法h += (1-pl)⊙Δ提供了一个内在学习信号——当某层的残差更新有噪声或冗余时,梯度倾向于增加pl以衰减该更新,即使没有正则化。路由器因此充当了一个学习到的噪声门。*(ii) 鲁棒性:*在λ∈[0,0.1]范围内,质相似文章
面向大型语言模型归因引导的持续学习
本文提出了一种面向大型语言模型的归因引导持续微调框架,该框架能够估计 Transformer 层中特定任务相关的参数重要性并相应地调节梯度,在保持新任务性能的同时缓解了灾难性遗忘。
使用稀疏Transformer进行生成建模
OpenAI推出了稀疏Transformer,一种深度神经网络,将注意力机制的复杂度从O(N²)优化到O(N√N),使得能够对长度超过以前30倍的序列进行建模,适用于文本、图像和音频领域。该模型采用稀疏注意力模式和基于检查点的内存优化技术,可以训练深达128层的网络,在多个领域实现了最先进的性能。
@simplifyinAI: DeepSeek 对 Transformer 架构进行了根本性重构。它解决了导致大规模 AI 模型崩溃的“身份危机”……
DeepSeek 发表了一篇论文,介绍了 mHC(流形约束超连接,Manifold-Constrained Hyper-Connections),这是一种对 Transformer 架构的根本性重写,通过用数学约束的多流路径替换标准残差连接,来稳定大型模型。
(1D) 有序词元实现高效测试时搜索
# 论文页面 - (1D) 有序词元实现高效测试时搜索 来源:[https://huggingface.co/papers/2604.15453](https://huggingface.co/papers/2604.15453) ## 摘要 具有“粗到细”词元结构的自回归模型在测试时扩展上表现更佳,并在与图文验证器结合后,实现无需训练的文本到图像生成。 [词元化](https://huggingface.co/papers?q=Tokenization) 是自回归(AR)生成模型的关键组件,将原始
TIDE:每一层都知晓上下文中的令牌
本文介绍了 TIDE,一种通过嵌入记忆(Embedding Memory)将令牌身份注入每一层,从而解决大语言模型(LLM)中罕见令牌问题和上下文崩溃问题的方法。作者在理论上和经验上证明了该方法在语言建模和下游任务中的改进。