Transformer中的自回归下一个词元预测与KV缓存
摘要
解释Transformer中的自回归下一个词元预测以及用于加速词元生成的KV缓存优化技术。
暂无内容
查看缓存全文
缓存时间: 2026/05/20 14:27
# Transformer中的自回归下一个Token预测与KV缓存
来源:https://medium.com/advanced-deep-learning/autoregressive-next-token-prediction-kv-cache-in-transformers-afad22285baf
Frederik vom Lehn (https://medium.com/@frederik.vl?source=post_page---byline--afad22285baf---------------------------------------)
理解LLM中加速token生成的关键优化技术
按回车或点击可查看全尺寸图片
总体概览(作者提供的图片)
## 宏观图景
在深入探讨注意力头、KV缓存以及生成机制之前,不妨先退一步,快速了解自回归语言模型究竟*是*什么。
一个提示以纯文本形式输入:“How are you?”。分词器将其切分为词汇ID——这里以`3, 7, 1, 9`表示,并在前面加上一个BOS(“序列开始”)token。每个ID只是一个整数,指向一个**查找表**:一个形状为`(vocab_size, c)`的已学习矩阵,其中每一行是词汇表中一个token的嵌入向量。为我们的5个输入ID选择对应的行,得到`X`,一个`(5, 4)`矩阵,五个token,每个都位于一个4维的嵌入空间中。至此,文本离开了符号世界,进入了向量世界。这里的维度只是示例中的简化尺寸。
从这里开始,`X`流过一叠**解码器块**。每个块都是相同的架构:多头自注意力后接一个MLP,每个块都将其输入转换为形状相同的精炼`(5, 4)`表示。使深层transformer可训练的关键是包裹在每个块周围的**残差连接**:每个块并不替换输入,而是在输入上*相加*(`X1 = X + block_output`)。信息沿着一条连续的“残差流”流动,每个层对其进行编辑而非覆盖。叠放三个这样的块,得到`X3`,即最终的隐藏状态。
最后一步则逆转了第一步。**解嵌入矩阵**(通常就是查找表的转置,因为输入和输出词汇表相同)将`X3`的每一行投影回词汇空间,产生一个`(5, 12)`的logits矩阵:每个位置上的每个词汇token都有一个得分。对于下一个token生成,只有最后一行有用。它的argmax就是模型想要说的下一个token。这里就是token ID 5。
这就是整个前向传播的俯瞰。本文的其余部分将聚焦于其中一个解码器块内部发生的情况,以及优化技术——**KV缓存**,正是它使得生成长序列成为可能。
让我们放大看看在一个解码层的第一次前向传播过程中,单层内部发生了什么。
按回车或点击可查看全尺寸图片
预填充前向传播(作者提供的图片)
## 预填充前向传播
在语言模型能生成任何新token之前,它必须先处理提示。这个步骤(**预填充**)将整个输入序列在一次并行前向传播中通过网络运行。它的作用有两个:产生第一个预测的token,以及填充KV缓存,使得后续的解码步骤保持廉价。
让我们在一个小型模型中一步步过一遍一个5-token提示的处理过程:隐藏维度`c = 4`,2个注意力头,词汇表大小为12。
### 从token到Q、K、V
输入`X`以一个`(5, 4)`矩阵到达:5个token,每个由从查找表中提取的4维嵌入表示。三个已学习的投影矩阵`Wq`、`Wk`、`Wv`,每个形状为`(4, 4)`,将`X`转换为查询(Query)、键(Key)和值(Value)矩阵`Q`、`K`、`V`,它们形状均为`(5, 4)`。
由于我们有2个头,每个`(5, 4)`矩阵按列分割为两个`(5, 2)`的切片,每个头一个切片。每个头会在自己的2维子空间中独立计算注意力。
### 头内部的注意力
在单个头内部,注意力是一种带权重的查找。该头的`Q`切片`(5, 2)`与它的`K`切片的转置相乘,产生一个`(5, 5)`的注意力得分矩阵——每个token的查询与每个token的键做点积。经过缩放和softmax(并应用因果掩码,因为这是一个自回归模型,token*t*不能看到位置 > *t*的token),该矩阵的每一行变成了一个概率分布,表示“我应该从哪些过去的token中提取信息”。
然后这些权重乘以该头的`V`切片`(5, 2)`,得到该头的输出,形状为`(5, 2)`:每个token现在持有来自其允许位置的值向量的上下文感知混合。
### 拼接与投影
两个头的输出被拼接回一个`(5, 4)`矩阵,然后通过一个输出投影`(4, 4)`。结果`X'`仍然是`(5, 4)`,形状与输入相同,但每一行现在反映了从整个序列中收集的信息。
### MLP
然后每个token的向量被独立地送入一个两层MLP。形状为`(4, 8)`的`W_up`将每一行扩展到8维,GeLU添加非线性,而形状为`(8, 4)`的`W_down`将其投影回低维。输出`X1`是`(5, 4)`,在实际模型中,它会馈入下一个transformer块。叠放几个这样的层(这里为3层),就得到了完整的前向传播。我们假设这里就是最后一层。
### Logits与第一个预测
在最后一层之后,`(5, 4)`的隐藏状态与解嵌入矩阵`(12, 4).T`相乘,得到形状为`(5, 12)`的logits,即每个位置上每个词汇token的得分。对于生成,只有**最后一行**重要:它告诉我们模型认为在token 5之后应该是什么。对该行取argmax(或采样),就得到第一个生成的token。在我们的例子中是token ID 5。
### 缓存保存了什么
这里有一个安静但关键的部分:在这次单次前向传播中,每一层都为提示计算了形状为`(5, 4)`的`K`和`V`。这些张量被**存储**起来。它们是未来所有token在该层关于提示所需知道的一切。嵌入、查询、MLP激活——全部丢弃。从此,生成进入解码模式,一次处理一个新token,并从该缓存中读取,而不是重新进行计算。
所以现在让我们理解全局图景:当我们使用KV缓存生成下一个token时会发生什么。
带有KV缓存的第二次前向传播(作者提供的图片)## 带有KV缓存的解码步骤
预填充完成后,模型切换到**解码模式**。每个后续token都由一次前向传播生成,该前向传播在结构上看起来与预填充相似——但只在一个*行*上操作,并依赖KV缓存记住之前的所有内容。
让我们继续示例。预填充预测了token 5,所以我们现在将token 5作为下一个步骤的输入重新馈入。
### 一个token进,一个token出
新的输入`X`是一个单行,形状为`(1, 4)`,也就是token 5的嵌入,从与预填充相同的查找表中检索。提示之前的5个token**不再**被重新馈入。它们不需要:模型在该层从它们那里所需的一切都已经存在于缓存中。
将这个`(1, 4)`的行乘以`Wq`、`Wk`、`Wv`(每个仍是`(4, 4)`),会得到新的`Q`、`K`和`V`,每个形状为`(1, 4)`。只有新token的查询、键和值被计算。
### 追加到缓存
新计算出的`K`和`V`行被追加到上一步缓存中的`K`和`V`矩阵。预填充后缓存保存了`(5, 4)`,现在保存了`(6, 4)` :提示的5行加上token 5的一个新行。这个拼接后的张量就是注意力将要读取的对象。
### 针对缓存的注意力
与之前一样按头分割,每个头现在有一个形状为`(1, 2)`的查询和一个完整的键/值矩阵,形状为`(6, 2)`。点积`Q · K^T`产生一个`(1, 6)`的得分行——token 5在所有6个位置上的注意力权重,包括它自己。这里不需要因果掩码:每个缓存位置在构造上都是过去的位置,因此每个得分都是有效的。
Softmax将其转化为一个概率分布,然后对`V` `(6, 2)`加权求和得到形状为`(1, 2)`的头输出。拼接两个头得到`(1, 4)`,输出投影`(4, 4)`产生形状为`(1, 4)`的`X'`。
### 为什么这很重要
比较形状。预填充处理了一个`(5, 4)`输入,并行地在5行上运行每个操作,这对于填充缓存是必要的。解码处理一个`(1, 4)`输入,在单行上运行每个操作,而缓存则在需要的地方(注意力内部)静默地提供历史上下文。MLP、投影、解嵌入,它们所做的工作量相当于无缓存前向传播的`1/N`。
这就是长上下文生成变得可行的根本原因。没有KV缓存,每个新token都意味着重新做整个预填充,而且一次比一次稍长,生成N个token的成本会呈二次方增长。有了缓存,每个新token的成本大致相同,再加上对不断增长的缓存进行廉价注意力求和的开销。
生成一个token,本质上是一小部分新鲜工作在大量被记住工作的肩膀上完成。
相似文章
@TheTuringPost: 为什么 KV cache 是 LLM 速度快的主要原因之一?KV cache 将注意力机制与生成阶段连接起来……
KV cache 在自回归生成过程中存储先前计算的键向量和值向量,使模型能够避免在每一步重新计算整个序列,从而显著加速推理,但代价是内存使用增加。
让每个 Token 都物尽其用:通过 KV 缓存淘汰提升长上下文性能
本文提出了一种基于学习的全局保留率 KV 缓存淘汰方法,通过选择性保留有用 Token 并减少注意力稀释来改善长上下文推理能力,同时显著降低内存占用。
自剪枝键值注意力:通过预测未来效用决定何时写入
提出了自剪枝键值注意力(SP-KV),一种通过学习预测键值对未来效用的机制,动态剪枝KV缓存,将内存使用和解码速度提升3-10倍,且性能下降极小。模型和效用预测器通过下一词元预测进行端到端联合训练。
ReST-KV:基于逐层输出重构与时空平滑的鲁棒 KV Cache 驱逐方法
本文介绍了 ReST-KV,一种用于大型语言模型的新型鲁棒 KV Cache 驱逐方法。该方法利用逐层输出重构与时空平滑技术来提升效率,显著降低了解码延迟,并在 LongBench 和 RULER 等长上下文基准测试中超越了现有的最先进基线模型。
Tensor Cache: 基于驱逐条件的Transformer关联记忆
Tensor Cache 引入了一种两级缓存机制,将滑动窗口注意力中驱逐的键值对压缩成固定大小的关联记忆,从而在无需无界内存增长的情况下改进长上下文语言建模。