使用稀疏Transformer进行生成建模
摘要
OpenAI推出了稀疏Transformer,一种深度神经网络,将注意力机制的复杂度从O(N²)优化到O(N√N),使得能够对长度超过以前30倍的序列进行建模,适用于文本、图像和音频领域。该模型采用稀疏注意力模式和基于检查点的内存优化技术,可以训练深达128层的网络,在多个领域实现了最先进的性能。
我们开发了稀疏Transformer,一种深度神经网络,在预测序列中的下一个元素方面创造了新纪录——无论是文本、图像还是声音。它采用了注意力机制的算法改进,能够从比以前长30倍的序列中提取模式。
查看缓存全文
缓存时间:
2026/04/20 14:46
# 使用稀疏Transformer进行生成建模
来源:https://openai.com/index/sparse-transformer/
我们开发了稀疏Transformer,一种深度神经网络,在预测序列中的下一个元素方面创造了新的记录——无论是文本、图像还是声音。它使用了对*注意力*机制的算法改进,能够从长度比以前可能的长度长30倍的序列中提取模式。
AI研究中存在的一个现有挑战是对复杂数据(如图像、视频或声音)中的长距离微妙相互依赖关系进行建模。稀疏Transformer将O(N²)的Transformer(https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html)自注意力机制重新表述为O(N√N)复杂度,并进行了多项其他改进,以便将其直接应用于这些丰富的数据类型。以前,用于这些数据的模型是针对特定领域专门设计的,或者难以扩展到长度超过数千个元素的序列。相比之下,我们的模型可以使用数百层对包含数万个元素的序列进行建模,在多个领域实现了最先进的性能。在OpenAI,我们正在使用它来帮助我们构建具有更强理解世界能力的AI系统。
在Transformer中,每个输出元素都连接到每个输入元素,它们之间的权重根据具体情况动态计算,这个过程称为*注意力*。虽然人们相信这使得Transformer比具有固定连接模式的模型更灵活,但在实践中,它需要为每一层和每个注意力头创建一个N×N的*注意力矩阵*,当应用于包含许多元素的数据类型(如图像或原始音频)时,会消耗大量内存。
深层Transformer(64层和4个头)的注意力内存使用情况(矩阵存储在内存中或在反向传播期间重新计算)。作为参考,用于深度学习的标准GPU通常具有12-32 GB的内存。
减少这种情况的一种方法是在反向传播期间从*检查点*重新计算注意力矩阵,这是深度学习中一种广泛使用的技术,可以以增加计算为代价来减少内存使用。当对Transformer中的注意力矩阵执行此操作时,意味着最大的内存成本与层数无关,使我们能够训练深度比以前可能的更深的网络。实际上,我们发现在CIFAR-10等基准任务上,深度达到128层的Transformer的性能优于较浅的网络。
为了训练这些深度更大的模型,我们对Transformer中的操作顺序进行了多项调整,并修改了初始化方案。完整的详细信息可以在我们的论文中看到。
但是,即使计算单个注意力矩阵对于非常大的输入也可能变得不切实际。我们改为使用稀疏注意力模式,其中每个输出位置仅从输入位置的一个子集计算权重。当该子集相对于完整输入集很小时(比如N√N个元素而不是N个元素),即使对于非常长的序列,生成的注意力计算也变得易于处理,算法复杂度为O(N√N)而不是O(N²)。
为了评估该方法的可行性,我们首先可视化了深层Transformer在图像上学习到的注意力模式,发现许多模式显示了可解释的和结构化的稀疏模式。下面的每个图像都显示了某个注意力头为了预测图像中的下一个值而关注的输入像素(以白色突出显示)。当输入部分集中在小子集上并显示高度规律性时,该层适合进行稀疏化。以下是128层模型在CIFAR-10图像上的一些样本:
虽然许多层显示了稀疏结构,但某些层显然显示了动态注意力,延伸到整个图像。为了保持我们网络学习此类模式的能力,我们实现了注意力矩阵的二维分解,其中网络可以通过两步稀疏注意力关注所有位置。
第一个版本是*步长*注意力,大致相当于每个位置关注其行和列,类似于上面网络学习到的注意力模式。(请注意,列注意力可以等价地表述为关注转置矩阵的行)。第二个版本是*固定*注意力,关注固定列和最新列元素之后的元素,这是我们在数据不适应二维结构(如文本)时发现有用的模式。有关更多详细信息,我们建议读者参考我们的论文。
稀疏Transformer在CIFAR-10、Enwik8和Imagenet 64的密度估计方面创造了新的最先进成绩。
各种基准数据集上的密度建模性能(比特/字节或维度)。M表示网络中使用的百万参数数,W表示网络的宽度,L表示层数,H表示头数。
我们还发现,稀疏注意力比完全注意力实现了更低的损失,而且速度明显更快(参见我们的论文中的比较)。这可能指向来自我们稀疏性模式的有用的归纳偏差,或者与密集注意力的潜在优化问题。
使用稀疏注意力的Transformer似乎具有全局结构的概念,可以通过查看图像补全来定性评估。这里我们可视化了一个在64×64 ImageNet上训练的模型:
用于训练稀疏注意力Transformer来补全图像的半图像
使用稀疏注意力Transformer补全的图像网格
用作样本训练稀疏注意力Transformer来补全图像的照片行
我们还用未调整的softmax温度1.0生成了完全无条件的样本。这些模型使用最大似然目标进行训练,众所周知这可以覆盖数据的所有模式(包括潜在不存在的模式),而不是增加较小部分数据的保真度。从这些模型中以未调整的温度进行采样使我们可以看到模型认为存在于世界中的图像的完整分布。因此,某些样本可能看起来很奇怪。
用于训练稀疏注意力Transformer模型的样本图像
Imagenet真实数据
稀疏Transformer也可以通过简单地更改位置嵌入来调整为生成原始音频而不是图像。随着深度学习扩展到新的数据类型,我们相信使用这类网络指定归纳偏差的便利性将成为有用的工具。
该模型在原始古典音乐片段上进行训练,并使用稀疏注意力生成长度为65,000的序列。这对应于约5秒的原始音频,我们在下面的每个片段中连接了多个样本。
通常,实现稀疏注意力会涉及将查询和密钥矩阵分块切片,为了便于实验,我们实现了一组块稀疏内核(https://openai.com/index/block-sparse-gpu-kernels/),它们在GPU上高效地执行这些操作。我们开源了这些内核,并在此存储库(https://github.com/openai/sparse_attention)中提供了稀疏注意力函数的示例。
- 我们引入的稀疏注意力模式只是朝着高效建模长序列方向的初步步骤。我们认为探索不同的模式和稀疏性组合是有用的,学习稀疏模式对于下一代神经网络架构的研究来说是一个特别有前景的方向。
- 即使有我们上面描述的改进,自回归序列生成对于非常高分辨率的图像或视频来说仍然看起来不切实际。但是,我们引入的优化注意力操作可能是有用的原语,可以与其他高维数据建模方法(如多尺度方法)结合使用。
如果您有兴趣推进AI能力并帮助推进我们确保AI造福人类的使命,我们正在招聘(https://openai.com/careers/)!
相似文章
OpenAI Blog
OpenAI 研究人员提出了一种训练稀疏神经网络的方法,通过强制大部分权重为零使其更易于解释,从而发现能够解释模型行为的小型解耦电路,同时保持性能。这项工作旨在推进机制可解释性,作为对稠密网络事后分析的补充,并支持 AI 安全目标。
Reddit r/MachineLearning
作者分享在 162 MB Transformer 上把 FP16 + ONNX + 剪枝用到极致却收益递减的经历,求教下一步该选量化、蒸馏、低秩分解还是硬件级技巧。
OpenAI Blog
OpenAI 推出 GPT-2,这是一个拥有 15 亿参数的基于 Transformer 的语言模型,在 40GB 的互联网文本上进行训练,在语言建模基准上达到了最先进的性能,并在阅读理解、翻译、问答和摘要生成等任务上展示了零样本学习能力。出于安全考虑,仅公开发布了较小的模型和技术论文,而非完整的训练模型。
OpenAI Blog
OpenAI 提出了一种两阶段方法来改进语言理解:首先在大规模无监督数据集上使用语言建模对 transformer 模型进行预训练,然后在较小的有监督数据集上针对特定任务进行微调。该方法在包括常识推理、语义相似度和阅读理解在内的多种任务上取得了最先进的成果,同时需要的超参数调优工作最少。
arXiv cs.LG
本文提出了Token-Selective Attention (TSA),一种可微的token路由机制,它学习在每个token上跳过Transformer层中不必要的计算,从而在语言建模任务中将token层操作减少14-23%,且质量损失极小。