@hamzaelshafie: 新深度博客文章:《剖析ThunderKittens:高性能AI内核的紧凑型DSL解剖》这篇帖子……

X AI KOLs Following 工具

摘要

一篇详细分析ThunderKittens的博客文章,ThunderKittens是用于高性能AI内核的紧凑型DSL。文章包括从底向上的抽象分析,以及一个实现非因果注意力预填充内核的基准测试,该内核比FlashAttention-2快约1.55倍,与FlashAttention-3性能相当。

新深度博客文章:《剖析ThunderKittens:高性能AI内核的紧凑型DSL解剖》 这篇帖子是我从底层向上剖析ThunderKittens的尝试。我通过提出每个抽象真正给我们带来了什么来研究TK:它对应哪个硬件细节,如何映射到GPU实际需要的底层布局,移除了哪些样板代码,以及GPU编程模型的哪些部分仍然对我们作为内核作者可见。 帖子详细介绍了TK提供的瓦片抽象:寄存器、共享内存和张量内存瓦片,全局布局,向量抽象,线程束/线程束组计算,TMA,交换,Hopper WGMMA,Blackwell tcgen05,2xSM MMA,张量内存,集群启动控制,TK的流水线模板,以及静态持久化瓦片调度。 在末尾,我通过实现一个非因果注意力预填充内核并在H100 PCIe上针对不同序列长度与FlashAttention-2和FlashAttention-3进行基准测试,演示了TK的lcf流水线模板。该内核在整个扫描范围内平均比FA2快约1.55倍,并且与FA3紧密接近,其中FA3在较长序列长度上仅快约1.05倍至1.17倍。 博客链接:https://hamzaelshafie.bearblog.dev/dissecting-thunderkittens-anatomy-of-a-compact-dsl-for-high-performance-ai-kernels/… 仓库:https://github.com/HamzaElshafie/tk_attention… 我还在末尾提供了详尽的资源列表,对感兴趣的读者非常有用。 请注意:这是我个人的独立撰写。我与@HazyResearch无关联,文中如有错误均由我负责。若发现错误,请与我联系! 1 / xx
查看原文
查看缓存全文

缓存时间: 2026/05/21 21:37

新博文深入解析:“剖析 ThunderKittens:一个面向高性能 AI 内核的紧凑型 DSL 的解剖结构” 本文尝试从底层到上层全面剖析 ThunderKittens。我通过询问每个抽象究竟为我们带来了什么来切入 TK:它对应哪个硬件细节,如何映射到底层 GPU 实际需要的布局,它消除了哪些样板代码,以及 GPU 编程模型中有哪些部分仍然暴露给我们这些内核作者。本文详细介绍了 TK 提供的 tile 抽象:寄存器、共享和张量内存 tile、全局布局、向量抽象、warp/warpgroup 计算、TMA、swizzle、Hopper WGMMA、Blackwell tcgen05、2xSM MMA、张量内存、集群启动控制、TK 的流水线模板以及静态持久 tile 调度。最后,我通过实现一个非因果注意力预填充内核,并在 H100 PCIe 上针对不同序列长度与 FlashAttention-2 和 FlashAttention-3 进行基准测试,来演示 TK 的 lcf 流水线模板。该内核在整个扫描范围内平均比 FA2 快约 1.55 倍,并且紧追 FA3,在较长序列长度上 FA3 仅快约 1.05-1.17 倍。博客链接:https://hamzaelshafie.bearblog.dev/dissecting-thunderkittens-anatomy-of-a-compact-dsl-for-high-performance-ai-kernels/… 仓库:https://github.com/HamzaElshafie/tk_attention… 我还在末尾附上了一组广泛的参考资料,我认为对感兴趣的读者非常有用。请注意:这是我个人的独立撰文。我与 @HazyResearch 无关,文中的任何错误均由我负责。如果您发现任何错误,请与我联系!1 / xx


剖析 ThunderKittens:一个面向高性能 AI 内核的紧凑型 DSL 的解剖结构

来源:https://hamzaelshafie.bearblog.dev/dissecting-thunderkittens-anatomy-of-a-compact-dsl-for-high-performance-ai-kernels/
2026年5月21日

引言

现代 ML 工作负载高度依赖自定义 GPU 内核。即使模型以清晰的张量运算表达,性能几乎总是来自底层专门的实现。这方面的典型例子包括各种不同的注意力机制、不同精度下的 GEMM,以及 MoE 风格的分组 GEMM,这些已成为当前最先进模型中相当常见的架构选择。如果从缩放定律的角度来看,这一点非常重要。更好的模型通常来自算法质量、更多数据和更多算力的某种组合。如果我们想继续推动这一进步,我们关心的不仅是算法质量,还有这些算法在硬件上实际运行的效率。正如 Tri Dao 所说(https://www.youtube.com/watch?v=5qSN-R_E3w0),一个清晰的表述是:

IntelligenceDollar = IntelligenceFLOPS⏟算法与数据效率 × FLOPSDollar⏟硬件效率

我们希望改进这两个项。在算法方面,研究人员需要快速迭代新架构以及新的训练和推理方案。但要让这一切在大规模上发挥作用,必须将其转化为能在真实硬件上快速运行的代码。这里存在一个持续的矛盾:我们希望编程环境对研究足够高效,同时又足够贴近底层以获得卓越性能并良好扩展。这正是 GPU 编程 DSL 所处的空间,它们覆盖了相当广泛的范围。在高端,像 PyTorch 这样的框架让研究人员无需考虑 GPU 即可编写张量表达式。框架负责内核调度,通过 PyTorch 2(https://pytorch.org/get-started/pytorch-2-x/),TorchDynamo + TorchInductor 可以生成有竞争力的 GPU 代码,通常通过发射 Triton 实现。一个更低的层次,Triton 提供了对分块、内存访问模式和程序结构更显式的控制,同时仍隐藏了大部分 CUDA 复杂性。再往下走,进入 CUDA C++、CUTLASS/CuTe 或 PTX,我们获得了直接的硬件控制,但现在需要管理内存布局、warp 同步、张量核心调度以及大量样板代码。越深入,我们能够推理的 GPU 层次结构就越多,但做任何事情所需的专业知识和代码也越多。

ThunderKittens(https://github.com/HazyResearch/ThunderKittens/tree/main)是斯坦福大学 Hazy Research Lab(https://hazyresearch.stanford.edu/)的一个嵌入在 CUDA 中的 DSL,它处于这个谱系上一个真正有趣的位置。其背后的研究问题很清晰:编程抽象可以做到多小,同时仍能支持广泛 AI 工作负载上的快速内核? TK 既没有隐藏硬件,也没有暴露所有细节,而是找到了一个中间地带:它抽象了重复的管道工作——tile 布局、共享内存分配、寄存器片段、TMA 张量映射、张量核心描述符——同时仍让我们足够接近底层,以便仔细推理数据移动的位置、流水线如何分段以及工作如何调度。而且由于它嵌入在 CUDA 中,当我们需要库未暴露的功能时,我们总能降级到原始 CUDA 或内联 PTX。

这就是我在这篇文章中想要使用的框架。我想了解 TK 暴露了哪些抽象,它们为什么如此设计,以及它们如何映射到 Hopper 和 Blackwell GPU 硬件。我们将从核心编程模型开始:全局布局、共享 tile、寄存器 tile、向量、计算包装器和内存移动。然后我们将看看最新的 Blackwell 专用新增功能:tcgen05、2xSM MMA、张量内存和集群启动控制。最后,我们将通过使用 TK 的 lcf 流水线模板构建一个注意力预填充内核,并对比 FlashAttention-2 和 3 进行基准测试,使所有内容具体化。

ThunderKittens 编程模型与核心抽象

在深入之前,先退一步思考 ThunderKittens 试图做什么是有帮助的。在高层,TK 提供了一套精心设计的抽象,这些抽象很好地映射到 AI 内核。TK 不是直接用原始 CUDA 编写所有内容,而是让我们使用 tile 以及对这些 tile 执行的高级操作。其中许多操作感觉与 PyTorch 原语有些相似,这为具有 ML 背景的人提供了一种熟悉的内核开发体验。这确实是该框架背后的核心思想:降低编写高性能内核的复杂性,同时不放弃高效利用现代 GPU 所需的控制。正如论文所述:“尽管表面上需要大量技术来利用所有这些硬件功能,但我们的核心技术发现是,实际上,对于许多 AI 内核,存在少量关键抽象可以简化编写高性能内核的过程。” 在真正理解这一主张之前,对 GPU 硬件本身有一个扎实的心理模型非常重要。如果 warp、线程块、共享内存、张量核心和占用率等概念尚未直观,我强烈建议先阅读我的 H100 GEMM(https://hamzaelshafie.bearblog.dev/worklog-optimising-gemm-on-nvidia-h100-for-cublas-like-performance-wip/)优化文章的开头部分。对于较新的 Blackwell 组件,我们将在本博客中讨论,请耐心等待。有了这个基础,TK 编程模型就变得更容易推理了。从那里,我们可以开始查看 ThunderKittens 围绕的核心抽象,以及为什么它们如此自然地映射到现代 GPU 硬件上。

作为快速回顾,以下是 GPU 内存层次结构及其对应的 CUDA 编程模型,以及 TK 的 tile 抽象(请原谅我笨拙的三角形):

TK_pyramid_fixed

我们现在可以看到 TK 的 tile 抽象如何融入这个内存层次结构的视角。Tile 抽象是 TK 编程模型的基本构建块;所有其他组件都分层在其之上。在高层,ThunderKittens 的主要抽象如下所示:

TK-toolbox

Tile 抽象

TK 建立在这样一个理念之上:所有内容都应该以 tile 的形式表达,而这些 tile 能够干净地映射到 GPU 层次结构上。在最基本的层面上,TK 使用一个高度固定为 16 的基础 tile,而宽度取决于数据类型。对于 fp16bf16 和其他 16 列的情况,基础 tile 是 16×16。对于 1 字节类型(如 fp8),它变为 16×32(https://hazyresearch.stanford.edu/blog/2024-11-27-tk-fp8)。这不是一个任意的选择。它直接源于张量核心指令如何暴露其片段(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wgmma#asynchronous-warpgroup-level-matrix-fragment),以及这些片段如何需要通过共享内存进行 staging。我们已经在较老的张量核心指令中看到了这一点。在 SASS 中,HMMA 指令通常作用于诸如 16×8×1616×8×8 的形状,因此即使硬件指令本身实际上不是方形的 16×16 乘法,16 的幅度在片段布局中仍然是真实存在的。在 Hopper 上,warp group 版本表现为 HGMMA,具有更大的形状,例如 64×256×16。所以重点不在于每个张量核心指令恰好是 16×16。重点在于硬件暴露了非常强的以 16 为基础的结构,更大的片段是通过重复组合该结构的各个部分来构建的。一旦我们查看 Hopper WGMMA 输入片段,数据类型依赖性就变得特别明显。对于 16 位输入,m64nNk16 系列自然暴露了一个 64×16 的输入 slab,TK 可以将其视为重复的 16×16 片段。对于 fp8,相应的系列是 m64nNk32,现在同一个输入侧加宽到 64×32。这正是 TK 将行粒度固定为 16,但允许列粒度对于 1 字节类型加宽的原因。抽象仍然是相同的抽象,但它所围绕的硬件片段现在更宽了。

(修正后)wgmma 的寄存器片段视图

因此,当 TK 构建更大的 tile 时,最简单的思考方式是堆叠这些基础片段,直到覆盖完整的 tile 形状。一个更大的 16 位 tile,例如 st_bf<64,64>,可以看作是一个 4×4 排列的 16×16 片段。一个 fp8 tile,例如 st_fp8e4m3<64,64>,则变成了一个 4×2 排列的 16×32 片段。这是推理 TK 处理形状的有用方式,即使共享 tile 本身仍然存储为单个共享内存对象。

逻辑组合图

灯泡表情符号

在继续之前有一个小提示:TK 并不会试图暴露 Hopper 或 Blackwell 硬件支持的每一种数值格式。在我正在查看的当前仓库状态(commit #01cb68c(https://github.com/HamzaElshafie/ThunderKittens/commit/01cb68cd9c0693eb56e6e21fd7e0413461a1648e))中,主要的 tile 类型涵盖了 TK 实际围绕构建的格式:bf16half(FP16)、float(FP32)、常见的 FP8 格式 fp8e4m3fp8e5m2、Blackwell 的 fp8e8m0,以及打包的 FP4 存储,例如 fp4e2m1_2。FP4 是在我们心理模型中需要额外注意的一个。TK 确实定义了一个标量 fp4e2m1 类型,但 tile 类型使用打包存储,因为单个 FP4 值只有 4 位。两个 FP4 值自然适合一个字节,因此面向 tile 的类型类似于 fp4e2m1_2。这意味着打包的 FP4 在 TK 的可寻址 tile 单元中仍然看起来像一个 16×32 的基础 tile,但如果我们计算实际的单个 FP4 值,同一个 tile 代表 16×64 个标量。所以一般规则仍然成立,我们只需要记住一旦数据类型本身被打包,“一个元素”意味着什么。

这是关键思想。TK 的 tile 形状自然地适应张量核心风格的矩阵片段,不是因为每条指令实际上就是一个方形 tile,而是因为它们提供了一个干净的软件单元来组合硬件实际使用的更大片段。这在较老的 HMMA 风格指令中已经可见,其中 16 的粒度很明显,并且在 Hopper HGMMA 指令中仍然成立,尽管完整的 warp group 操作要大得多。因此,当 TK 构建更大的东西时,并不是在 GPU 上强制推行一个笨拙的抽象。它只是将更小的片段组合成已经匹配硬件工作方式的更大 tile。在实践中,TK 围绕三个核心 tile 级抽象构建其编程模型:全局布局描述符(gl共享 tile(st寄存器 tile(rt。除此之外,它还提供了共享和寄存器 向量抽象(svrv,用于那些更自然地以向量形式表达的内核,例如 LayerNorm 或 RMSNorm。也就是说,TK 2.0 现在主要针对 Hopper 和 Blackwell GPU 构建和测试。该项目特别声明它不再积极支持 Ampere

在更详细地查看各个抽象之前,快速注意一下 TK 代码库中常见的几个通用常量和线程索引辅助函数会很有用。这些通常直接以 kittens::xxx 形式访问。

常量/辅助函数意图
BASE_TILE_DIM16基础的 16 粒度。
TILE_COL_DIM16,或对于 1 字节类型(如 FP8)为 32类型 T 的基础 tile 宽度。
TILE_ROW_DIM16类型 T 的基础 tile 高度。
TILE_ELEMENTSTILE_COL_DIM * TILE_ROW_DIM一个基础 tile 中的元素数量。
WARP_THREADS32一个 warp 中的线程数。
WARPGROUP_THREADS128一个 warpgroup 中的线程数。
WARPGROUP_WARPS4一个 warpgroup 中的 warp 数。
warpid()threadIdx.x >> 5块中的 warp 索引。
warpgroupid()threadIdx.x >> 7块中的 warpgroup 索引。
laneid()threadIdx.x & 0x1fwarp 中的 lane 索引。

这为本章节的其余部分提供了一个更清晰的基础,因为现在后面的抽象可以基于相同的底层思想来理解:TK 保持行粒度固定为 16,让宽度跟随数据类型,然后通过组合已经匹配硬件片段结构的形状来构建更大的共享和寄存器级对象。

  1. 寄存器 tile:寄存器 tile 是 TK 用于在计算期间存储在寄存器中的值的主要抽象。在 GEMM 风格的内核中,这些通常是持有累加器片段的 tile,这就是寄存器 tile 在张量核心指令周围如此突出的原因。在源代码中,一般形式是 rt,因此类型由数据类型、形状和布局参数化。但在实践中,我们通常会看到较短的别名,例如用于 FP32 寄存器 tile 的 rt_fl 或用于 BF16 寄存器 tile 的 rt_bf。在底层,寄存器 tile 遵循我们上面介绍的“构建块”故事。行粒度保持固定为 16,而宽度取决于数据类型。在 rt_base.cuh 中,基本寄存器片段直接使用 TILE_ROW_DIMTILE_COL_DIM,然后在 rt.cuh 中,更大的寄存器 tile 通过将这些基础片段显式组成一个 2D 网格来形成:rt_base tiles[height][width]。因此,例如,一个 rt_fl<64, 64> 内部是一个 4×416×16 基础寄存器片段网格。对于 fp8 寄存器 tile,同样的想法仍然成立,只是基础片段加宽到 16×32,因此得到的网格相应改变。这是寄存器 tile 抽象特别清晰的地方之一:与共享 tile 不同(共享 tile 只能在逻辑上接受这种解释),寄存器 tile 在源代码中直接表示为数据类型特定基础片段的网格。

这种数据类型依赖性在 TK 需要在寄存器 tile 类型之间进行转换时变得尤为重要。人们很容易认为这只是独立地转换每个值。但对于张量核心寄存器片段,这更复杂。这些值并不是作为简单的连续数组存在于单个线程中。它们已经按照硬件预期的布局分布在各个 lane 中,就像我们在前面的可视化中看到的那样。对于 16 位类型(如 fp16bf16),一个 32 位寄存器自然地打包两个值。对于 fp8,同样的 32 位数量打包四个值。因此,一个 fp8

相似文章