@thtrkim: FlashAttention 的手动可视化深入讲解(使用 Excalidraw 绘制)https://winterrykim.github.io/blog/2026/training-lm-…
摘要
深入理解 FlashAttention 的可视化讲解,涵盖内存优化和算子融合,以实现语言模型训练中的高效注意力计算。
FlashAttention 的手动可视化深入讲解(使用 Excalidraw 绘制)https://winterrykim.github.io/blog/2026/training-lm-from-scratch-part2-flashattention-memory/…
查看缓存全文
缓存时间: 2026/06/24 10:22
动手深度探索 FlashAttention(手绘 Excalidraw) https://winterrykim.github.io/blog/2026/training-lm-from-scratch-part2-flashattention-memory/… — # 从零训练语言模型(第二部分:FlashAttention 与设备内存) 来源:https://winterrykim.github.io/blog/2026/training-lm-from-scratch-part2-flashattention-memory/ 在之前的博客文章(https://winterrykim.github.io/blog/2026/training-lm-from-scratch-part1-building-blocks/)中,我介绍了语言模型的基本构建块:分词、嵌入、RoPE、注意力等。但如果放大注意力部分,之前的文章大多将其视为一个干净的等式: \[O = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V\] 这个数学对象是正确的,但它隐藏了一个系统问题:当计算机试图运行这个等式时,实际发生了什么?本文要探讨的就是这一缺失的层次。不仅是“什么是注意力?”,还有为什么朴素的注意力会变得内存密集,为什么 softmax 可能比 FLOP 计数显示的要慢,以及为什么 FlashAttention 是一个如此巧妙的算法技巧。 一句话总结: FlashAttention 实现了精确的注意力,但避免将巨大的注意力分数/概率矩阵写入 HBM。它对 Q/K/V 分块进行流式处理,在线维护 softmax 统计量,只保存输出和 log-sum-exp,并在反向传播中重新计算所需的内容。 — ## 先做性能分析 从性能分析时间线中可以观察到一些情况: 注意力内核的 Nsight 风格性能分析时间线 性能分析表明,注意力不仅仅是 matmul。即使 softmax 的 FLOP 计数远小于 matmul,它也可能占用时间线中有意义的一部分。第一个有用的划分是,“昂贵”这个词并不单一。对于 FLOP 来说,明显的巨大运算是矩阵乘法。在注意力中,这些运算是: \[QK^\top\] 以及: \[PV\] 其中: \[P = \operatorname{softmax}(S), \quad S = \frac{QK^\top}{\sqrt{d}}\] 但挂钟时间不仅仅是 FLOP。像 softmax 这样的运算可能耗时较长,因为它们归约密集且内存流量密集。稳定的 softmax 需要扫描一行以找到最大值、指数化、求和、除法并写入结果。此外,这些运算并不是加速器极其擅长的专用 GEMM 工作。 因此,与传统算法一样,我们有两个轴: 1. 速度 2. 内存 对于注意力来说,内存尤其痛苦,因为朴素注意力会实例化形状如下的矩阵: \[[B, H, N, N]\] 其中 N 是序列长度。这个 N^2 是悄然将长上下文变成问题的部分。 — ## FlashAttention 之前 在进入 FlashAttention 之前,有几个常见的技术几乎随处可见。 ### 混合精度 细节混合精度在通常安全且更快的地方使用较低精度,同时将更敏感的操作保持在较高精度。例如,matmul 通常可以使用 FP16/BF16 来利用专用的 matmul 单元并降低内存带宽。但归约、归一化、损失、优化器状态和长累加通常需要更多关注。关键不是“让一切都降低精度”,而是策略性精度。 ### 算子融合 细节算子融合将多个运算合并到一个内核中。如果我们这样做: x = x - x.max(dim=-1, keepdim=True).values num = torch.exp(x) den = num.sum(dim=-1, keepdim=True) y = num / den 朴素实现可能会在每一步之间读写中间张量。融合实现可以将更多工作保留在内核内部,减少内存流量。这也是 softmax 是一个好例子的原因之一。其算术并不可怕,但反复写入和读取中间行是浪费的。 ### 重计算 / 检查点 细节Autograd 通常在正向传播中保存张量,因为反向传播稍后需要它们。检查点改变了权衡:保存更少的中间张量,并在反向传播中重新计算缺失的值。因此,权衡是: \[\text{更少内存} \quad \leftrightarrow \quad \text{更多计算}\]这个想法将在 FlashAttention 反向传播中再次出现。我们不再保存巨大的 S 和 P 矩阵,而是保存足够的信息,以便稍后重新计算必要的概率块。 — ## 缺失的想法:IO 感知 上述技术很有用,但它们本身并不完全具备 IO 感知能力。这里说的 IO 感知是指:算法显式关心数据在不同级别内存之间移动的次数。从高层次看,相关的内存层次是: 寄存器 每线程值,最快,极小 片上 SRAM / 共享内存 GPU/加速器上的快速暂存器,但大小有限 HBM / 设备 DRAM 加速器外的内存,通常几十 GB NAND / SSD 持久化闪存,在注意力内核循环之外 对于 FlashAttention,主要目标就是减少 HBM 与片上内存层次之间的流量。 — ## 厨房比喻 关键在于,我们希望在相关分块已在片上内存时,尽可能少地访问 HBM,并尽可能多地完成工作。想象一位厨师在做牛肉汤。储藏室是 HBM。厨房台面是片上 SRAM/共享内存。厨师的手是寄存器。如果厨师每进行一个小步骤就要走到储藏室,那会很糟糕。切一种蔬菜,走回去存放。再拿出来。加盐,再走回去。再拿出来。这时厨师的双腿会燃烧。自然会问:为什么不一次性把所有的东西都拿到厨房?因为厨房很小,门很窄。更具体地说,加速器无法将整个注意力矩阵放入快速的片上内存。它必须通过寄存器/共享内存移动分块,而完整的张量则位于 HBM 中。这基本上就是注意力问题。在朴素注意力中,我们取出 Q、K、V 的大块,计算完整的 QK^\\top 矩阵,将其写入 HBM,读取回来进行 softmax,将概率写回,再次读取,然后乘以 V。在下面的符号中,我有时会省略 1/\\sqrt\{d\} 缩放而直接写 QK^\\top。缩放仍然是实际注意力的一部分。数学很简洁: \[S = QK^\top\] \[P = \operatorname{softmax}(S)\] \[O = PV\] 但内存移动很丑陋: 计算 S -> 将 S 写入 HBM 读取 S -> 计算 P -> 将 P 写入 HBM 读取 P 和 V -> 计算 O -> 写入 O FlashAttention 问:我们能避免这种来回传输吗? — ## FlashAttention 简述 FlashAttention 结合了分块、在线 softmax 和重计算。它不会近似注意力。输出仍然是: \[\operatorname{softmax}(QK^\top / \sqrt{d})V\] 区别在于算法的调度方式。FlashAttention 不会在 HBM 中实例化完整的 S 和 P 矩阵,而是: 1. 加载一个 Q 分块。 2. 流式处理 K 和 V 的分块。 3. 计算局部分数块。 4. 为每个查询行维护运行中的 softmax 统计量。 5. 累积输出分子。 6. 仅写入最终输出和一个小的 log-sum-exp 向量。 在 FlashAttention-2 中,正向传递使用一个外层循环遍历 Q 行块,和一个内层循环遍历 K/V 列块。最初的 FlashAttention 论文也使用了分块,但 FlashAttention-2 改变了循环顺序和工作划分,以提高并行性并减少不必要的共享内存流量。 — ## 原始注意力与 FlashAttention-2 分块 理解这一点最简单的方法是可视化矩阵乘法。 手绘 FlashAttention 分块示意图,显示 Q 行块和 K/V 列块 FlashAttention 将注意力拆分为行/列块的工作,而不是实例化整个注意力矩阵。输出侧的累积仍然需要在线 softmax 校正。从概念上讲,QK^\\top 计算每个查询 token 关注每个键 token 的程度。对于固定的查询行,softmax 权重分布在所有键位置上,输出是匹配的 V 行的加权和。因此,我们可以按行将这些工作拆分为块: Q 块 i 关注 K/V 块 0 Q 块 i 关注 K/V 块 1 Q 块 i 关注 K/V 块 2 ... 问题在于 softmax 耦合了整行。我们不能独立地对每个块进行 softmax 然后将答案拼接起来。softmax 的分母需要整行。这就是在线 softmax 发挥作用的地方。 — ## 累积部分输出 对于一个查询块,每个 K/V 块都会对最终输出贡献一部分。 手绘 FlashAttention 部分和示意图,显示跨 K/V 块的输出累积 每个 K/V 块贡献最终加权和的一部分。这些部分贡献在累积之前会使用运行中的最大值和分母进行重新缩放。如果 softmax 分母不是问题,这会很容易: \[O_i = \sum_j P_{ij}V_j\] 按 K/V 块拆分: \[O_i = \sum_{j \in \text{块 0}} P_{ij}V_j + \sum_{j \in \text{块 1}} P_{ij}V_j + \cdots\] 问:分块说得通,但我们如何解决 softmax? 如果 P 已经知道,累积过程就成立。但 P\_\{ij\} 依赖于所有 K/V 块的行最大值和分母,因此 softmax 似乎需要访问整行。我们如何能够流式处理块,同时获得与一次性看到整行相同的 softmax? — ## 安全 / 在线 softmax 通常,一行数据的稳定 softmax 是: \[\operatorname{softmax}(x_j) = \frac{e^{x_j - m}}{\sum_t e^{x_t - m}}, \quad m = \max_t x_t\] 这需要知道整行的最大值。但假设该行被分为块 A 和块 B。块 A 的最大值为 m\_A。块 B 的最大值为 m\_B。全局最大值是: \[m = \max(m_A, m_B)\] 如果我们先使用 m\_A 处理块 A,之后发现 m\_B 更大,那么旧的块 A 值是用错误的最大值归一化的。但这是可以修正的。对于来自块 A 的旧值 x: \[e^{x - m} = e^{x - m_A} \cdot e^{m_A - m}\] 这个微小的重新缩放项就是整个技巧: \[e^{m_\text{旧} - m_\text{新}}\] 因此,当我们流式处理块时,我们保持: \[m_i = \text{运行中的行最大值}\] \[\ell_i = \text{运行中的分母}\] \[u_i = \text{运行中的未归一化输出分子}\] 当新块到达时: \[m_\text{新} = \max(m_\text{旧}, m_\text{块})\] 重新缩放旧的分母和分子: \[\ell_\text{新} = e^{m_\text{旧}-m_\text{新}}\ell_\text{旧} + \sum_{j \in \text{块}} e^{s_j - m_\text{新}}\] \[u_\text{新} = e^{m_\text{旧}-m_\text{新}}u_\text{旧} + \sum_{j \in \text{块}} e^{s_j - m_\text{新}}v_j\] 最后: \[O_i = \frac{u_i}{\ell_i}\] 这就是在流式/分块设置中的安全 softmax 技巧。 — ## 使用 LSE 连接正向和反向传播 在反向传播中,我们不希望保存完整的 S 和 P 矩阵。但我们仍然需要恢复 P 块,因为梯度方程依赖于 softmax 概率。那么我们要保存什么呢? 手绘 log-sum-exp 技巧,用于在 FlashAttention 反向传播中重建 softmax 概率 FlashAttention 为每个查询行保存一个 log-sum-exp 值。+m\_i 项是行最大值稳定项;恢复的概率是 e^{S_{ij} - L_i}。紧凑的值是 log-sum-exp: \[L_i = \log \sum_t e^{S_{it}}\] 出于数值稳定性,计算方式为: \[L_i = m_i + \log \sum_t e^{S_{it} - m_i}\] 其中 m\_i 是行最大值。然后,在反向传播中,如果我们重新计算一个分数块: \[S_{ij} = \frac{q_i^\top k_j}{\sqrt{d}}\] 我们可以通过以下方式恢复 softmax 概率: \[P_{ij} = e^{S_{ij} - L_i}\] 这是关键步骤。我们不需要完整的旧概率矩阵。我们只需要 Q、K、V、输出 O、上游梯度 dO 和逐行的 L_i 值。然后,当该块已经加载时,我们重新计算概率块。有一个重要的细节:L_i 不是“左侧部分分母和右侧部分分子”。它是完整 softmax 分母的对数,以稳定的方式书写。从分数中减去 L_i 正好得到分数减去对数分母,指数化后得到 softmax 概率。 — ## 反向传播 现在我们有了反向传播的构建块。朴素注意力可能会保存像 S 和 P 这样的大型中间张量。FlashAttention 避免了保存它们。它保存或可以访问: - Q, K, V - 输出 O - 下游梯度 dO - 每个查询行的 log-sum-exp 向量 L - O 和 dO 的逐行点积(通常称为 D) 然后它逐块重新计算 S 和 P。 手绘 FlashAttention 反向传播示意图 反向传播再次在块上进行流式处理。它重新计算分数/概率块,为每个 K/V 块累积 dK/dV,并更新相应查询块的 dQ。概念上的顺序是: 1. 从 QK^\\top 和 LSE 重新计算局部 P 块。 2. 使用 P 计算 dV 和 dP。 3. 使用 softmax 梯度计算 dS。 4. 使用 dS 计算 dQ 和 dK。 在矩阵形式中,主要方程为: \[S = \frac{QK^\top}{\sqrt{d}}\] \[P = e^{S - L}\] 其中 L 在各行的列上广播。 \[dV = P^\top dO\] \[dP = dO V^\top\] \[dS = P \odot (dP - D)\] 这里的 D 是一个逐行修正项,通常由 dO 和 O 计算得出。关键在于,这个 softmax 梯度是逐元素的,而不是与 P 的矩阵乘法。 \[dQ = \frac{dS K}{\sqrt{d}}\] \[dK = \frac{dS^\top Q}{\sqrt{d}}\] 在分块的反向传播中,FlashAttention-2 从概念上翻转了调度:它可以将一个 K/V 块放在外层循环中,在内层循环中流式处理 Q/dO/LSE/O 块,为该 K/V 块累积 dK 和 dV,并为每个查询块更新 dQ。因此: 外层循环: K/V 块 内层循环: Q 块 这感觉很自然,因为反向传播是反向进行的:我们需要重建局部 P 块,使用它计算局部梯度,然后聚合各部分。 — ## 实现说明:Triton 下一个问题是:我们如何实际告诉计算机进行这种分块移动?PyTorch 张量代码非常适合表达数学,但它不会自然地让我们控制“加载这个块,保持这个累加器,避免实例化这个中间量”。CUDA 提供了这种控制,但级别非常低。Triton 处于两者之间。它提供了一种类似 Python 的方式来编写 GPU 内核,其中每个程序实例基本上是一个块工作器。 Triton 心理模型:Triton 概念我的理解grid启动多少个块工作器。tl.program_id(axis)当前块工作器的坐标。tl.arange一个块内的行/列偏移量。tl.load/tl.store显式地将块数据从全局内存移入/移出,通常使用边界块的掩码。tl.dot执行块 matmul。在我朴素的 forward 内核中,第一个网格轴选择 Q 块,第二个网格轴选择批次: pid_q = tl.program_id(axis=0) # 哪个 Q 块 pid_b = tl.program_id(axis=1) # 哪个批次 q_rows = pid_q * BLOCK_M + tl.arange(0, BLOCK_M) d_cols = tl.arange(0, BLOCK_D) 然后每个程序加载一个 Q 块: q_ptrs = ( q_ptr + pid_b * stride_qb + q_rows[:, None] * stride_qn + d_cols[None, :] * stride_qd ) q_tile = tl.load( q_ptrs, mask=(q_rows[:, None] < Nq) & (d_cols[None, :] < D), other=0.0, ) 内核将在线 softmax 状态保留在块形状的累加器中: m = tl.full((BLOCK_M,), -float("inf"), tl.float32) l = tl.full((BLOCK_M,), 0, tl.float32) acc = tl.full((BLOCK_M, BLOCK_D), 0, tl.float32) 然后它在 K/V 块上进行流式处理: `` scores = tl.dot(q_tile, tl.trans(k_tile)) * scale if IS_CAUSAL: scores = tl.where(q_rows[:, None] >= k_rows[None, :], scores, -float(“inf”)) m_old = m m_tile = tl.max(scores, axi
相似文章
@charles_irl: 去年秋天,我们分享了关于FA4内部机制的深度分析。但我们并未止步于理解内核。自那时起,我们一直在…
一篇博客文章详细介绍了对FlashAttention-4的贡献,通过调整并行策略和支持不规则内存访问,以提升其在大型语言模型推理中的性能,特别是针对解码密集型工作负载。
FlashMemory-DeepSeek-V4:通过前瞻稀疏注意力实现闪电索引超长上下文
提出在DeepSeek-V4上结合神经记忆索引器的前瞻稀疏注意力,将GPU内存使用降至全上下文基线的约13.5%,同时保持或略微提升准确率。
@levidiamode: GPU编程第157/365天:另一个对我非常有帮助的FlashAttention4资源是@charles_irl的演讲…
一个每日GPU编程帖子重点介绍了Charles_irl的演讲,该演讲在论文发布前逆向工程了FlashAttention4代码,并赞扬了Modal团队对代码的深入剖析和对前向传播的合理推断。
@levidiamode: GPU编程第158/365天——我觉得我大致理解了FlashAttention 2、3和4前向传播的高级区别…
作者记录了学习GPU编程的进展,重点在于理解FlashAttention 2、3和4前向传播的高级区别,并列出了需要进一步探索的几个底层概念。
动态线性注意力
本文提出DLA,一种用于多状态线性注意力的动态内存建模框架,它能根据令牌信息变化自适应地合并状态,并维护固定大小的状态缓存,从而在无需标准注意力二次复杂度的前提下实现更好的长上下文表示。