@charles_irl: 去年秋天,我们分享了关于FA4内部机制的深度分析。但我们并未止步于理解内核。自那时起,我们一直在…

X AI KOLs Following 新闻

摘要

一篇博客文章详细介绍了对FlashAttention-4的贡献,通过调整并行策略和支持不规则内存访问,以提升其在大型语言模型推理中的性能,特别是针对解码密集型工作负载。

去年秋天,我们分享了关于FA4内部机制的深度分析。 但我们并未止步于理解内核。 自那时起,我们一直在开发推理性能的改进并将其提交到上游。 这篇博客文章解释了这些贡献。 https://t.co/xzDNHdq3Zw https://t.co/AzFs33Xqif
查看原文
查看缓存全文

缓存时间: 2026/06/11 23:42

去年秋天,我们分享了关于 FA4 内部机制的深入分析。

但我们并没有停留在理解内核层面。

此后,我们一直在开发推理性能的改进,并将它们上游贡献出去。

这篇博文解释了这些贡献。

https://t.co/xzDNHdq3Zw https://t.co/AzFs33Xqif


让 FlashAttention-4 在推理时更快

来源:https://modal.com/blog/flash-attention-4-faster 去年 FlashAttention-4 内核源码发布时,我们深入研究了它,并详尽地分享了关于内核如何工作的发现(https://modal.com/blog/reverse-engineer-flash-attention-4)。现在你可以通过阅读这篇来自核心团队的官方文章(https://www.together.ai/blog/flashattention-4)来确认我们推断的顶层结构。

在随后的几个月里,我们对这个内核进行了一些改进,使其更适合大型语言模型的推理,特别是解码密集型(decode-heavy)工作负载。与预训练工作负载不同,LLM 推理工作负载通常受内存带宽限制(https://modal.com/gpu-glossary/perf/memory-bound),主要处于“解码”或“token 生成”阶段(下图中浅蓝色部分)。

推理工作负载通常也更加多变——批量大小和序列长度变得不均匀;键值对大多需要从缓存中检索。

这需要新的内核代码,而且这段代码必须快速:“性能即产品”(https://modal.com/gpu-glossary/perf)。

在深入细节之前,先为更广泛的读者提供一些要点。

关于底层编程的高层要点

我们对内核所做的扩展,以适应我们想要运行的推理工作负载,大致可以分为两类:

  • 调整并行策略,即每个线程块的查询瓦片数量,以及从查询并行切换到键/值并行;
  • 支持不规则全局内存访问,即使用 cp.async 加载代替使用张量内存加速器(TMA)(https://modal.com/gpu-glossary/device-hardware/tensor-memory-accelerator)的 cp.async.bulk 加载。

以下图表分别展示了这两类优化,后文将详细说明。

没有 KV 并行和有 KV 并行的输出瓦片生成示意图我们的优化之一是将“split KV”技术移植到 FA4。这实现了跨 KV 瓦片的并行化(右侧)。规则与不规则全局内存访问示意图我们的多项优化需要处理不规则内存访问(右侧),这使用了与规则访问(左侧)不同的指令和硬件。调整并行策略对于提升现代大规模并行硬件(https://modal.com/gpu-glossary/perf/roofline-model)上的性能具有最大的杠杆作用。直观地说:如果你被锁定在某种特定的并行方法中,Amdahl 定律中的顺序项就固定了。如果你能改变并行策略,就可以在算法的并行部分和顺序部分之间移动工作。根据该定律,这通常比提高固定并行部分的速度具有更高的杠杆作用。

我们并没有选择 CUDA 模板领域特定语言(CuTe DSL)(https://modal.com/gpu-glossary/host-software/cute-dsl),这是原始内核作者的选择,但它对我们来说效果很好。它通过快速的 JIT 编译支持高生产力的开发循环,且运行时成本极低或为零。它也比旧工具更容易表达我们的许多想法。请注意,由于使用了模板,FA4 实际上是一个内核家族,如果“内核”大致意思是“可以启动到 CUDA 流中的东西”。我们仍会称它为“内核”。

CuTe DSL 很好。但是,正如我们在上一篇文章(https://modal.com/blog/reverse-engineer-flash-attention-4)中指出的,FA4 最好在瓦片级别上算法性地理解,而不是在实现的 warp(https://modal.com/gpu-glossary/device-software/warp)级别。显然,基于瓦片的适当编程在人体工程学和开发速度(顺便说一句,在智能体时代(https://modal.com/blog/agents-devex)这仍然很重要)上会更好。使用基于瓦片的编程模型,程序员可以更简单地表达和操作瓦片级别的流程。这可以更低的工程成本(第一类更改)来更改或向内核添加算法。此外,更高级的基于瓦片的模型使编译器更容易实现和优化,例如 cp.async 和 TMA 加载路径(第二类),并根据大小等进行调度。

因此,我们非常期待对 CUDA 瓦片编程模型(https://modal.com/gpu-glossary/device-software/cuda-tile-programming-model)的更好支持,以区别于经典的“CUDA SIMT”编程模型(https://modal.com/gpu-glossary/device-software/cuda-programming-model),从而构建未来的注意力机制和矩阵乘法内核。

我们做了什么、为什么这么做,以及我们如何知道效果不错

我们按 Pull Request 来组织我们的贡献。每部分都以“品质因数”开头:用于表明该贡献改进了性能的衡量标准。我们以性能工程师的传统格式——ASCII 表格——来报告这些数据。

PR 2109 (https://github.com/Dao-AILab/flash-attention/pull/2109):支持 FP8 输入(2026年4月17日合并)

品质因数:相对于 bf16 基线,吞吐量最高提升 1.16 倍

训练模型通常需要更高精度的浮点数(https://modal.com/llm-almanac/quant-formats/)来正确累积梯度中的许多微小变化。但在推理时,我们可以使用较低精度。将位宽减少一半,可以将内存和算术带宽需求降低一半,而对模型质量的影响却小得多。

这对于大型模型中的 MLP/MoE 层尤其如此,这些层通常使用微小的“半字节”大小的 4 位浮点数(https://modal.com/llm-almanac/quant-formats/0x6)。注意力操作,尤其是在长上下文上,涉及更多累加,因此更难量化。像 gpt-oss(https://modal.com/docs/examples/gpt_oss_inference)这样的模型将单精度(https://modal.com/llm-almanac/quant-formats/bf::0x0380)注意力操作与 4 位矩阵乘法结合使用,以两全其美。

然而,像 DeepSeek-V3 和 V4(https://modal.com/docs/examples/deepseek_v4)这样的关键模型系列原生地(即从训练开始)支持 8 位(https://modal.com/llm-almanac/quant-formats/e4::0x58)注意力操作。而其他模型,如 Qwen 和 Gemma 系列,有时会部署 8 位 KV 缓存以加速推理。

因此,我们添加了对 8 位浮点数(具有 4 或 5 个指数位,即 e4m3(https://modal.com/llm-almanac/quant-formats/e4::0x38)或 e5m2(https://modal.com/llm-almanac/quant-formats/e5::0x1c))的支持。相对于下面讨论的其他更改,这相当直接:更少的字节移动和操作意味着更快的推理!它还意味着更小的 KV 缓存,这进一步支持更长的上下文和/或推理期间更高的用户并发度。

值得注意的是,加速比低于从位宽减半可能预期的 2 倍,位宽减半同时将内存带宽(https://modal.com/gpu-glossary/perf/memory-bandwidth)和(有效)算术带宽(https://modal.com/gpu-glossary/perf/arithmetic-bandwidth)的需求降低一半。要确定这里的具体瓶颈(https://modal.com/gpu-glossary/perf/performance-bottleneck),需要更详细的分析。但结果与 softmax 操作的瓶颈一致,即使张量核心(Tensor Cores)(https://modal.com/gpu-glossary/device-hardware/tensor-core)在较低精度输入上运行,softmax 操作仍以相同精度(在 CUDA 核心(https://modal.com/gpu-glossary/device-hardware/cuda-core)和/或特殊函数单元(Special Function Units)(https://modal.com/gpu-glossary/device-hardware/special-function-unit)上)运行。

PR 1999 (https://github.com/Dao-AILab/flash-attention/pull/1999) 和 PR 2104 (https://github.com/Dao-AILab/flash-attention/pull/2104):支持任意 KV 页面大小(2025年11月13日合并)并优化性能(2026年1月15日合并)

品质因数:对于小页面大小,吞吐量最高提升 2.40 倍

FlashAttention-4 操作在瓦片(https://modal.com/blog/reverse-engineer-flash-attention-4)上,这些瓦片的大小旨在有效利用 Blackwell 张量核心(https://modal.com/gpu-glossary/device-hardware/tensor-core)。在推理的解码阶段,键和值张量的瓦片由 KV 缓存中的条目构成,这些条目在预填充期间被填充。在原始版本的 FlashAttention-4 中,KV 缓存页面需要与瓦片大小相同。

这个限制源于内核使用了张量内存加速器(TMA)(https://modal.com/gpu-glossary/device-hardware/tensor-memory-accelerator),这是使用 Hopper 和 Blackwell 流式多处理器(SM)架构(https://modal.com/gpu-glossary/device-hardware/streaming-multiprocessor-architecture)的 GPU 中用于某些规则内存访问的硬件引擎。TMA 大幅加速了大型仿射内存访问——即那些看起来像“偏移量加上步长乘以形状”的访问(对于多个步长),就像通过 CuTe Layout(https://modal.com/gpu-glossary/host-software/cute)访问时一样。如果页面大小足够大,这可以很好地用于访问基于页面的 KV 缓存(https://arxiv.org/abs/2309.06180)。

但是 TMA 无法在一次加载中将多个分散的块收集到一个单一的瓦片中,并且它不会加速(甚至可能减慢)较小的加载,这些加载是较小页面大小带来的后果。

因此,我们添加了一条使用 cpasync(CuTe DSL 对 PTX(https://modal.com/gpu-glossary/device-software/parallel-thread-execution)cp.async 指令的封装)的路径,通过 PagedKVManager 实现。

在基于 TMA 的版本中,一个 warp(https://modal.com/gpu-glossary/device-software/warp)中的单个线程(https://modal.com/gpu-glossary/device-software/thread)负责加载一个瓦片——生产者-消费者模型中的“生产者组”是单个线程。

cpasync 版本中,每个线程发起一次加载(warp 的加载由硬件合并(https://modal.com/gpu-glossary/perf/memory-coalescing)),因此它们计算自己的 page 和页面内的 offset。这很简单但效率不高;稍后会详细说明!

我们重新利用了原本空闲的 warp 15 来处理这个额外的工作——生产者组由两个 warp 组成。

在第一个 PR 中,这些较小的页面大小导致算术和内存吞吐量较低。但在许多推理工作负载中,KV 缓存效率非常重要,因此这可能是一个很好的权衡。

首先,大的页面大小可能导致不必要的重复。如果多个请求共享一个例如 64 个 token 的前缀,但之后不同,则使用 page_size=128 的注意力内核将需要为每个请求单独分配一个页面,因为前缀比页面大小短。使用 page_size=16 的注意力内核可以在多个请求之间共享四个页面,从而将所需存储量按请求数量进行乘法级减少(参见下图中左侧三个请求共享前缀“Thou shalt not”,而在右侧使用较大 page_size 的 KV 缓存中,该前缀被重复三次)。

大的页面大小导致 KV 缓存的严重内部碎片化。短序列仍然需要完整的页面——在最坏的情况下,单个 token 消耗了整个本可以容纳 128 个 token 的 KV 缓存数据的页面。对于该块来说,这导致了超过 99% 的内部碎片化。这消耗了大约 8 倍于 page_size=16 的 KV 缓存容量,后者“仅”有 93.75% 的内部碎片化。

这对于推测性解码尤其重要。推测器在 KV 缓存中创建许多短的(大约 1-16 个 token)序列,使用大的页面大小时,每个序列都会占用更多空间。

支持任意页面大小对于兼容性已经是一个胜利,但第一个实现的性能成本很高。对于最极端的情况 page_size=1,在 FA4 内核的内存受限(https://modal.com/gpu-glossary/perf/memory-bound)情况下,内存吞吐量不到有效内存带宽(https://modal.com/gpu-glossary/perf/memory-bandwidth)的一半;在计算受限(https://modal.com/gpu-glossary/perf/compute-bound)情况下,算术吞吐量不到有效算术带宽(https://modal.com/gpu-glossary/perf/arithmetic-bandwidth)的三分之一。我们在后续的 PR(https://github.com/Dao-AILab/flash-attention/pull/2104)中修复了性能问题。

类似的问题也影响了 FlashAttention-3 内核,因此我们将策略移植到了 FA4 的 PagedKVManager 中。

关键举措是将地址生成与地址使用解耦,以减少冗余计算。这是通过“转置”地址生成来完成的,如下所述。该方法也在 Zadouri 等人的论文(https://arxiv.org/abs/2505.21487)第 4.2 节中有详细描述。

我们将每个 warp(https://modal.com/gpu-glossary/device-software/warp)中的 32 个线程(https://modal.com/gpu-glossary/device-software/thread)组织成一个数组,包含四个“行”线程组,每组有八个“列”线程:

我们的原始方法是让每个线程计算它同时负责加载的 KV 缓存行的指针。

这里的加载模式受硬件约束——为了获得良好的内存合并(https://modal.com/gpu-glossary/perf/memory-coalescing),线程应访问连续内存。通过行方向加载,相邻线程最终会冗余计算同一行指针。

不幸的是,这种冗余代价高昂。指针是 64 位的,int64 操作代价高昂(近期数据中心 GPU 的 FLOP 和矩阵乘法 FLOP 的算术带宽(https://modal.com/gpu-glossary/perf/arithmetic-bandwidth)增长远超其他操作的带宽)。当需要计算更多地址时(例如使用较小的页面大小),这个成本更高。

解决方案是预先产生所有 32 个行指针,然后循环执行加载。这引入了跨线程同步(以 warp shuffle 的形式),但比地址计算便宜。

我们使用的具体模式是转置:warp 中一个“行”组内的八个线程为 1)不同的行产生行指针,2)这些行在逻辑上不是连续的。相反,跨组的“列”中的线程负责计算(但不使用)连续的行指针。

这相比于旧方法,内存吞吐量最多提升了 2.4 倍(对于 page_size=1),达到了与我们在更大页面大小下观察到的相同或更高的吞吐量。

PR 1940 (https://github.com/Dao-AILab/flash-attention/pull/1940):在 KV 维度上增加并行性(2025年11月4日合并)

品质因数:对于短查询长度,吞吐量最高提升 4.37 倍

推理性能通常由解码时间主导。一个“典型”的推理请求大部分时间用于一次或少数几次生成 token,基于一个或几个查询与许多缓存的 KV 值进行匹配。

但原始的 FlashAttention-4 内核架构在查询维度上并行化工作,而不是在键/值维度上。对于小批量推理(这对于高交互性、延迟敏感的应用(https://modal.com/blog/decagon-case-study)至关重要),这就像氪石。内核程序的不同可并行实例(协作线程数组(https://modal.com/gpu-glossary/device-software/cooperative-thread-array))的数量通常远低于流式多处理器(SM)(https://modal.com/gpu-glossary/de

相似文章

@derangineer: 游戏中的山羊

X AI KOLs Following

Charles Frye 宣布了一篇博客文章,详细介绍了对 FA4 内部结构的贡献,重点在于已上游的推理性能改进。