Transformer中的自回归下一个词元预测与KV缓存

Hacker News Top 工具

摘要

解释Transformer中的自回归下一个词元预测以及用于加速词元生成的KV缓存优化技术。

暂无内容
查看原文
查看缓存全文

缓存时间: 2026/05/20 14:27

# Transformer中的自回归下一个Token预测与KV缓存 来源:https://medium.com/advanced-deep-learning/autoregressive-next-token-prediction-kv-cache-in-transformers-afad22285baf Frederik vom Lehn (https://medium.com/@frederik.vl?source=post_page---byline--afad22285baf---------------------------------------) 理解LLM中加速token生成的关键优化技术 按回车或点击可查看全尺寸图片 总体概览(作者提供的图片) ## 宏观图景 在深入探讨注意力头、KV缓存以及生成机制之前,不妨先退一步,快速了解自回归语言模型究竟*是*什么。 一个提示以纯文本形式输入:“How are you?”。分词器将其切分为词汇ID——这里以`3, 7, 1, 9`表示,并在前面加上一个BOS(“序列开始”)token。每个ID只是一个整数,指向一个**查找表**:一个形状为`(vocab_size, c)`的已学习矩阵,其中每一行是词汇表中一个token的嵌入向量。为我们的5个输入ID选择对应的行,得到`X`,一个`(5, 4)`矩阵,五个token,每个都位于一个4维的嵌入空间中。至此,文本离开了符号世界,进入了向量世界。这里的维度只是示例中的简化尺寸。 从这里开始,`X`流过一叠**解码器块**。每个块都是相同的架构:多头自注意力后接一个MLP,每个块都将其输入转换为形状相同的精炼`(5, 4)`表示。使深层transformer可训练的关键是包裹在每个块周围的**残差连接**:每个块并不替换输入,而是在输入上*相加*(`X1 = X + block_output`)。信息沿着一条连续的“残差流”流动,每个层对其进行编辑而非覆盖。叠放三个这样的块,得到`X3`,即最终的隐藏状态。 最后一步则逆转了第一步。**解嵌入矩阵**(通常就是查找表的转置,因为输入和输出词汇表相同)将`X3`的每一行投影回词汇空间,产生一个`(5, 12)`的logits矩阵:每个位置上的每个词汇token都有一个得分。对于下一个token生成,只有最后一行有用。它的argmax就是模型想要说的下一个token。这里就是token ID 5。 这就是整个前向传播的俯瞰。本文的其余部分将聚焦于其中一个解码器块内部发生的情况,以及优化技术——**KV缓存**,正是它使得生成长序列成为可能。 让我们放大看看在一个解码层的第一次前向传播过程中,单层内部发生了什么。 按回车或点击可查看全尺寸图片 预填充前向传播(作者提供的图片) ## 预填充前向传播 在语言模型能生成任何新token之前,它必须先处理提示。这个步骤(**预填充**)将整个输入序列在一次并行前向传播中通过网络运行。它的作用有两个:产生第一个预测的token,以及填充KV缓存,使得后续的解码步骤保持廉价。 让我们在一个小型模型中一步步过一遍一个5-token提示的处理过程:隐藏维度`c = 4`,2个注意力头,词汇表大小为12。 ### 从token到Q、K、V 输入`X`以一个`(5, 4)`矩阵到达:5个token,每个由从查找表中提取的4维嵌入表示。三个已学习的投影矩阵`Wq`、`Wk`、`Wv`,每个形状为`(4, 4)`,将`X`转换为查询(Query)、键(Key)和值(Value)矩阵`Q`、`K`、`V`,它们形状均为`(5, 4)`。 由于我们有2个头,每个`(5, 4)`矩阵按列分割为两个`(5, 2)`的切片,每个头一个切片。每个头会在自己的2维子空间中独立计算注意力。 ### 头内部的注意力 在单个头内部,注意力是一种带权重的查找。该头的`Q`切片`(5, 2)`与它的`K`切片的转置相乘,产生一个`(5, 5)`的注意力得分矩阵——每个token的查询与每个token的键做点积。经过缩放和softmax(并应用因果掩码,因为这是一个自回归模型,token*t*不能看到位置 > *t*的token),该矩阵的每一行变成了一个概率分布,表示“我应该从哪些过去的token中提取信息”。 然后这些权重乘以该头的`V`切片`(5, 2)`,得到该头的输出,形状为`(5, 2)`:每个token现在持有来自其允许位置的值向量的上下文感知混合。 ### 拼接与投影 两个头的输出被拼接回一个`(5, 4)`矩阵,然后通过一个输出投影`(4, 4)`。结果`X'`仍然是`(5, 4)`,形状与输入相同,但每一行现在反映了从整个序列中收集的信息。 ### MLP 然后每个token的向量被独立地送入一个两层MLP。形状为`(4, 8)`的`W_up`将每一行扩展到8维,GeLU添加非线性,而形状为`(8, 4)`的`W_down`将其投影回低维。输出`X1`是`(5, 4)`,在实际模型中,它会馈入下一个transformer块。叠放几个这样的层(这里为3层),就得到了完整的前向传播。我们假设这里就是最后一层。 ### Logits与第一个预测 在最后一层之后,`(5, 4)`的隐藏状态与解嵌入矩阵`(12, 4).T`相乘,得到形状为`(5, 12)`的logits,即每个位置上每个词汇token的得分。对于生成,只有**最后一行**重要:它告诉我们模型认为在token 5之后应该是什么。对该行取argmax(或采样),就得到第一个生成的token。在我们的例子中是token ID 5。 ### 缓存保存了什么 这里有一个安静但关键的部分:在这次单次前向传播中,每一层都为提示计算了形状为`(5, 4)`的`K`和`V`。这些张量被**存储**起来。它们是未来所有token在该层关于提示所需知道的一切。嵌入、查询、MLP激活——全部丢弃。从此,生成进入解码模式,一次处理一个新token,并从该缓存中读取,而不是重新进行计算。 所以现在让我们理解全局图景:当我们使用KV缓存生成下一个token时会发生什么。 带有KV缓存的第二次前向传播(作者提供的图片)## 带有KV缓存的解码步骤 预填充完成后,模型切换到**解码模式**。每个后续token都由一次前向传播生成,该前向传播在结构上看起来与预填充相似——但只在一个*行*上操作,并依赖KV缓存记住之前的所有内容。 让我们继续示例。预填充预测了token 5,所以我们现在将token 5作为下一个步骤的输入重新馈入。 ### 一个token进,一个token出 新的输入`X`是一个单行,形状为`(1, 4)`,也就是token 5的嵌入,从与预填充相同的查找表中检索。提示之前的5个token**不再**被重新馈入。它们不需要:模型在该层从它们那里所需的一切都已经存在于缓存中。 将这个`(1, 4)`的行乘以`Wq`、`Wk`、`Wv`(每个仍是`(4, 4)`),会得到新的`Q`、`K`和`V`,每个形状为`(1, 4)`。只有新token的查询、键和值被计算。 ### 追加到缓存 新计算出的`K`和`V`行被追加到上一步缓存中的`K`和`V`矩阵。预填充后缓存保存了`(5, 4)`,现在保存了`(6, 4)` :提示的5行加上token 5的一个新行。这个拼接后的张量就是注意力将要读取的对象。 ### 针对缓存的注意力 与之前一样按头分割,每个头现在有一个形状为`(1, 2)`的查询和一个完整的键/值矩阵,形状为`(6, 2)`。点积`Q · K^T`产生一个`(1, 6)`的得分行——token 5在所有6个位置上的注意力权重,包括它自己。这里不需要因果掩码:每个缓存位置在构造上都是过去的位置,因此每个得分都是有效的。 Softmax将其转化为一个概率分布,然后对`V` `(6, 2)`加权求和得到形状为`(1, 2)`的头输出。拼接两个头得到`(1, 4)`,输出投影`(4, 4)`产生形状为`(1, 4)`的`X'`。 ### 为什么这很重要 比较形状。预填充处理了一个`(5, 4)`输入,并行地在5行上运行每个操作,这对于填充缓存是必要的。解码处理一个`(1, 4)`输入,在单行上运行每个操作,而缓存则在需要的地方(注意力内部)静默地提供历史上下文。MLP、投影、解嵌入,它们所做的工作量相当于无缓存前向传播的`1/N`。 这就是长上下文生成变得可行的根本原因。没有KV缓存,每个新token都意味着重新做整个预填充,而且一次比一次稍长,生成N个token的成本会呈二次方增长。有了缓存,每个新token的成本大致相同,再加上对不断增长的缓存进行廉价注意力求和的开销。 生成一个token,本质上是一小部分新鲜工作在大量被记住工作的肩膀上完成。

相似文章

自剪枝键值注意力:通过预测未来效用决定何时写入

arXiv cs.LG

提出了自剪枝键值注意力(SP-KV),一种通过学习预测键值对未来效用的机制,动态剪枝KV缓存,将内存使用和解码速度提升3-10倍,且性能下降极小。模型和效用预测器通过下一词元预测进行端到端联合训练。