WAV:面向深度仅解码器Transformer的多分辨率块残差路由

arXiv cs.LG 论文

摘要

本文提出多分辨率残差路由方法WAV v1,这是块注意力残差机制的扩展,通过引入方向性细节基来增强块表示,从而改进深度仅解码器Transformer的训练效果。

arXiv:2606.06564v1 公告类型:新 摘要:残差连接是训练深度Transformer的核心,但标准的PreNorm残差流以固定单位权重聚合子层更新。近期提出的注意力残差机制用基于内容的深度路由替代了这种固定累加,而块注意力残差通过路由块级残差摘要使该机制更加高效。然而,单个块摘要仅存储块内的低频总残差位移,丢弃了方向性结构,例如注意力与MLP之间的不平衡以及块内早期与晚期的动态差异。我们提出WAV v1,一种轻量级的多分辨率残差路由方法,适用于仅解码器Transformer。WAV v1不是仅用累加残差和表示每个块,而是为每个块补充两个方向性细节基:一个相位基,用于对比注意力和MLP更新;一个分割基,用于对比早期和晚期子层更新。这些基与标准块摘要一起通过相同的深度softmax混合器进行路由,同时负细节源初始化和分离的均方根匹配稳定了训练过程。在字符级TinyStories和Text8语言建模任务上,WAV v1显示出明显的深度依赖性收益。虽然在12层时并不总是有益,但在24层时变得有竞争力,并在48层时优于所有基线。在48层时,相对于Block AttnRes,WAV v1将TinyStories的验证损失从0.4960降至0.4738,将Text8的验证损失从0.9363降至0.9305,且额外参数可忽略不计。这些结果表明,方向性残差细节(而非仅仅块级和)对于在更深Transformer中扩展残差路由至关重要。
查看原文
查看缓存全文

缓存时间: 2026/06/08 09:16

# 深度解码器专用Transformer的多分辨率残差路由
来源: https://arxiv.org/html/2606.06564

###### 摘要

残差连接是训练深度Transformer的核心,但标准的PreNorm残差流以固定的单位权重聚合子层更新。最近的注意力残差用内容相关的深度路由替代了固定累加,而块注意力残差通过对块级残差摘要进行路由来提高效率。然而,单个块摘要仅存储块内的低频、总残差位移,丢弃了方向结构,例如注意力与MLP的不平衡以及早期与晚期块的动态变化。我们提出*多分辨率残差路由*,实例化为WAV v1,这是块注意力残差的一种轻量级扩展,为每个块表示增加两个零和方向细节基:一个相位基,对比注意力和MLP更新;一个分裂基,对比块的前半部分和后半部分。这些细节基与块摘要由相同的深度softmax混合器路由,但通过负偏置和均方根匹配引入,以保持训练早期稳定性。在使用TinyStories和Text8进行字符级GPT解码器专用语言建模时,WAV v1表现出强烈的深度依赖趋势:在12层时无益,在24层时变得有竞争力,在48层时在两个数据集上均达到最佳验证损失。在48层时,WAV v1在TinyStories上比Block AttnRes验证损失降低0.0222,在Text8上降低0.0057,同时保持注意力和MLP模块不变。这些初步结果表明,深度残差路由不仅受益于选择*读哪个块*,还受益于路由每个块的内部方向结构。

## 1 引言

现代解码器专用Transformer通常使用PreNorm残差连接进行训练,其中每个子层对残差流贡献一个加性更新。这种简单结构使得训练非常深的网络成为可能,但它也施加了一个固定的聚合规则:每个残差更新都以固定单位权重累积。随着深度增加,这种均匀累积可能会稀释单个层的贡献,使残差流越来越冗余。注意力残差通过用学习到的对先前层输出的softmax注意力代替固定残差累积来解决这个问题,使每一层都能执行内容相关的深度路由[3 (https://arxiv.org/html/2606.06564#bib.bib3)]。块注意力残差通过将多个层压缩成块级残差摘要,进一步降低了内存和通信开销。

本文的中心假设是,单个块摘要仍然是对块内残差轨迹的不完整表示。如果一个块包含一个子层更新序列{ub,i}i=1m,Block AttnRes只存储它们的和Cb=∑iub,i。这类似于保留块轨迹的低频或DC分量。然而,轨迹的内部形状可能携带有用信息。例如,该块可能以注意力为主或以MLP为主;其早期更新可能指向与晚期更新不同的方向。当仅用Cb表示块时,这种方向信息就丢失了。

我们引入WAV v1,这是Block AttnRes的一种最小多分辨率扩展。对于每个块,WAV v1存储原始块和Cb以及两个零和细节基:Dbphase,注意力和MLP更新之间的差异;Dbsplit,前半部分和后半部分更新之间的差异。这些基的计算成本很低,因为它们是在生成残差更新时在线累积的。它们也是保守的:主要的注意力和MLP函数保持不变,细节源接收初始负偏置,细节的RMS与相应的块摘要匹配,并且最终的预测混合器默认不使用细节源。

我们的初步实验在TinyStories和Text8上评估了12层、24层和48层的GPT解码器专用模型。结果揭示了清晰的深度依赖模式。在12层时,WAV v1表现不如Block AttnRes,表明当残差深度较浅时,方向细节源没有用处。在24层时,WAV v1变得有竞争力。在48层时,WAV v1始终最佳,在两个数据集上都优于Block AttnRes,并大幅超过ReZero和LayerScale。这种模式支持了多分辨率残差信息随着残差轨迹长度增长而变得更有价值的观点。

本文草稿作出三项贡献:

1.  我们将块残差路由表述为一个多分辨率表示问题:块摘要提供低频状态信息,而方向细节基提供块内轨迹信息。
2.  我们提出WAV v1,它是Block AttnRes的即插即用扩展,在保持原始注意力和MLP模块不变的情况下,为每个块增加相位和分裂细节基。
3.  我们在两个字符级语言建模数据集上提供了初步的扩展证据,表明WAV v1随深度增加而变强,并在评估的残差机制中取得了最佳的48层验证损失。

## 2 相关工作

#### 深度网络中的残差连接。

残差学习由ResNets[1 (https://arxiv.org/html/2606.06564#bib.bib1)]推广,其中恒等跳跃连接使得非常深网络的优化变得容易得多。Transformer[2 (https://arxiv.org/html/2606.06564#bib.bib2)]继承了这一残差原则,并通常使用PreNorm变体以保证稳定性。然而,标准残差加法将每次更新的聚合系数固定为1,这在非常深的架构中可能变得次优。

#### 学习的残差缩放。

一些方法通过学习缩放残差更新来改进深度训练。ReZero引入了一个初始化为零的标量残差门,使得在很大深度下也能稳定信号传播[5 (https://arxiv.org/html/2606.06564#bib.bib5)]。LayerScale使用每个通道可学习的残差缩放,已被证明能改进深度视觉Transformer[6 (https://arxiv.org/html/2606.06564#bib.bib6)]。这些方法调节残差幅度,但它们不对先前的残差状态进行内容相关的选择。

#### 注意力残差与块路由。

注意力残差用学习到的对先前表示的注意力替代加性残差累积[3 (https://arxiv.org/html/2606.06564#bib.bib3)]。Block AttnRes通过将层分组为块并对块级表示进行路由来压缩该机制。这相对于对每一层进行注意力,大大降低了成本。一个相关的近期方向,Delta注意力残差,对残差增量而不是累积隐藏状态进行路由,强调所选路由表示的重要性[4 (https://arxiv.org/html/2606.06564#bib.bib4)]。我们的工作是互补的:我们保留Block AttnRes的块级效率,但用结构化的方向细节丰富每个块表示。

## 3 方法

### 3.1 预备知识:块残差路由

考虑一个具有L层的解码器专用Transformer。每个Transformer层包含一个注意力子层和一个MLP子层。我们将每个子层输出视为一个残差更新。对于包含m个子层更新{ub,i}i=1m的块b,Block AttnRes存储块级残差摘要

Cb=∑i=1mub,i。 (1)

在后面的子层,深度混合器接收一个源集合,例如

Sblock={e,C0,C1,...,Cb−1,PC}, (2)

其中e是词嵌入源,PC是当前部分块和。

给定源张量{sj}j=1S、查询向量q和源偏置{βj},混合器计算

lj = q⊤RMSNorm(sj)+βj, (3)

αj = exp(lj)∑k=1Sexp(lk), (4)

h = ∑j=1Sαjsj。 (5)

得到的h用作注意力子层或MLP子层的上下文输入。

### 3.2 多分辨率块基

WAV v1保留原始块摘要Cb,但用两个零和细节基进行扩充。让ai∈{+1,−1}指示更新ub,i来自注意力子层还是MLP子层:

ai={+1,ub,i来自注意力,−1,ub,i来自MLP。 (6)

相位基为

Dbphase=∑i=1maiub,i。 (7)

该基捕获块的残差位移是以注意力类更新为主还是MLP类更新为主。

类似地,让ri∈{+1,−1}指示子层更新在块的前半部分还是后半部分:

ri={+1,i≤m/2,−1,i\>m/2。 (8)

分裂基为

Dbsplit=∑i=1mriub,i。 (9)

该基捕获块内部粗略的早期vs晚期运动。

得到的源集合为

SWAV={e,C0,D~0phase,D~0split,...,Cb−1,D~b−1phase,D~b−1split,PC,P~phase,P~split}。 (10)

### 3.3 稳定的细节注入

直接将细节源添加到深度混合器中可能会破坏早期训练的稳定性。因此我们使用两种保守机制。

#### 负细节偏置。

两个细节源初始化为负源偏置βD=−2.0,而嵌入和C源使用零偏置。这使得WAV v1在初始化时接近Block AttnRes,并让模型仅在有益时逐渐增加细节使用。

#### 分离的RMS匹配。

对于与块摘要C相关的细节张量D,我们计算

D~=D⋅stopgrad(clip(RMS(C)RMS(D)+ε,1ρ,ρ)), (11)

其中ρ是最大缩放因子。这防止细节源仅仅因为其原始规模大于C而被激活。在我们的实现中,最终的预测混合器只读取嵌入和C源,不读取细节源。

参见图注图1:WAV v1的详细概览。(a) 在每个残差块内,子层更新累加为一个状态基Cb和两个方向细节基Dbphase和Dbsplit。(b) 与Block AttnRes相比,深度混合器接收一个扩展的源池,包括已完成块和部分块的细节源。(c) 在正向步骤中,MLP分支在注意力更新写入后读取部分基;最终的预测读出仅使用嵌入和C源。细节源通过负初始偏置和分离的RMS匹配进行稳定。

### 3.4 计算成本

WAV v1保持注意力、MLP、词嵌入和输出头不变。其额外成本来自增加块级路由源的数量。如果模型有N个残差块,Block AttnRes路由大约O(N)个块摘要,而WAV v1路由大约O(3N)个块基源。渐近成本仍然是块级,而不是层级,因为它不需要对先前所有子层状态进行注意力。额外参数可忽略不计:每个层有两个用于注意力混合器的细节偏置和两个用于MLP混合器的细节偏置,即每个Transformer层四个标量参数。

## 4 实验

### 4.1 设置

我们在TinyStories[9 (https://arxiv.org/html/2606.06564#bib.bib9)]和Text8[10 (https://arxiv.org/html/2606.06564#bib.bib10)]上评估字符级GPT解码器专用语言模型。所有模型使用PreNorm RMSNorm[7 (https://arxiv.org/html/2606.06564#bib.bib7)]、因果自注意力和SwiGLU MLP[8 (https://arxiv.org/html/2606.06564#bib.bib8)]。我们比较五种残差机制:标准残差、Block AttnRes、ReZero、LayerScale和WAV v1。当前arXiv草稿报告了我们实验摘要中的验证损失;误差条将在原始日志完全合并后添加。

表1:用于初步arXiv草稿的实验设置。
### 4.2 主要结果

表2:50k步时的最终验证损失。数值越低越好。Δ与Block的差值为WAV v1减去Block AttnRes,因此负值表示改进。PPL降低根据验证损失计算为1−exp(LWAV)/exp(LBlock)。表2 (https://arxiv.org/html/2606.06564#S4.T2)显示了最终验证损失。最重要的结果是深度依赖趋势。在12层时,WAV v1在两个数据集上均劣于Block AttnRes。在24层时,WAV v1变得有竞争力:它在TinyStories上略有改进,在Text8上接近。在48层时,WAV v1在两个数据集上都是最佳方法。在TinyStories上,它将验证损失从0.4960降低到0.4738(相对于Block AttnRes)。在Text8上,它将验证损失从0.9363降低到0.9305。

参见图注图2:WAV v1相对于Block AttnRes的深度依赖增益。负值表示WAV v1更好。该方法在浅层时无益,但在48层时明显更强。参见图注图3:不同深度和数据集上的最终验证损失。WAV v1在48层时最强,而Block AttnRes在较浅深度时最强或有竞争力。
### 4.3 48层时的训练动态

图4 (https://arxiv.org/html/2606.06564#S4.F4)比较了48层的训练曲线。在TinyStories上,WAV v1很早就与Block AttnRes分离,并在整个训练过程中保持持续优势。在Text8上,优势较小,但训练结束时一致。这些曲线表明,改进不仅仅是一个晚期阶段的人工产物:当模型足够深时,多分辨率源可以影响优化轨迹。

参见图注图4:48层模型的验证损失曲线。WAV v1在TinyStories和Text8上都获得了最佳最终验证损失。
### 4.4 排名总结

表3:按最终验证损失排名的最佳和次佳方法。表3 (https://arxiv.org/html/2606.06564#S4.T3)中的排名突显了一个关键的质性区别。Block AttnRes在浅层和中层深度时非常有竞争力,但WAV v1在评估的最深配置中变成了最强。这支持了这样的解释:当每个块总结一个较长

相似文章

Block-Based Double Decoders

arXiv cs.LG

提出了一种基于块的雙解碼器(block-based double decoders),这是一种使用双重因果块注意力掩码的新型Transformer架构,结合了解码器仅训练效率与编码器-解码器推理效率,实现了强大的扩展性能并减少了KV缓存内存。

Delta Attention Residuals

Hugging Face Daily Papers

Delta Attention Residuals 通过关注特征变化(增量)而非累积隐藏状态,改进了Transformer模型中的逐层路由,在220M到7.6B参数的规模上实现了1.7-8.2%的验证困惑度提升。

学习跳跃块:自我发现的超度量路由用于硬件加速稀疏注意力

Reddit r/artificial

本文介绍了动态超度量注意力(Dynamic Ultrametric Attention),这是一个框架,其中Transformer在训练期间学习每头块稀疏路由拓扑,然后在推理时将这些拓扑卸载到自定义的Triton块稀疏内核上,与密集注意力相比,实现了高达28倍的加速和98.4%的内存减少。