@Andy_ShuoYang: Flash-KMeans 只是一个开始。今天,Flash-KMeans 团队发布了 FlashLib——一个用于……的 GPU 库。

X AI KOLs Following 工具

摘要

Flash-KMeans 团队发布了 FlashLib,这是一个面向经典机器学习算子的 GPU 库,在 Hopper GPU 上相比 cuML 可实现高达 208 倍的加速,专注于为智能体 AI 工作负载提供快速、可预测的性能。

Flash-KMeans 只是一个开始。 今天,我们 Flash-KMeans 团队发布了 FlashLib——一个面向快速、可预测、智能体就绪的经典机器学习算子的 GPU 库。 在 KMeans 上最高 26 倍,KNN 上 19 倍,HDBSCAN 上 40 倍,TruncatedSVD 上 208 倍,PCA 上 47 倍,精确 t-SNE 上 147 倍,MultinomialNB 上 49 倍,均优于当前先进水平 (cuML)。 博客:https://flashml-org.github.io 代码:https://github.com/FlashML-org/flashlib…
查看原文
查看缓存全文

缓存时间: 2026/05/27 03:00

Flash-KMeans 仅仅是一个开始。

今天,来自 Flash-KMeans 团队,我们发布了 FlashLib —— 一个用于快速、可预测、适合智能体的经典机器学习算子的 GPU 库。

在 KMeans 上最高达 26 倍,KNN 上 19 倍,HDBSCAN 上 40 倍,TruncatedSVD 上 208 倍,PCA 上 47 倍,精确 t-SNE 上 147 倍,MultinomialNB 上 49 倍,均超越当前最优 (cuML)。

博客:https://flashml-org.github.io 代码:https://github.com/FlashML-org/flashlib…


FlashLib:为经典机器学习算子带来魔法般的加速

来源:https://flashml-org.github.io/ Shuo Yang1, Haocheng Xi1, Yilong Zhao1, Qiuyang Mang1, Zhe Wang2, Shanlin Sun2, Kurt Keutzer1, Joseph E. Gonzalez1, Song Han3, Chenfeng Xu4,*, Ion Stoica1,*

代码:github.com/FlashML-org/flashlib (https://github.com/FlashML-org/flashlib)

26 倍 KMeans

19 倍 KNN

208 倍 TruncatedSVD

47 倍 PCA

7 倍 UMAP

40 倍 HDBSCAN

147 倍 t-SNE (精确)

49 倍 MultinomialNB

介绍 FlashLib —— 一个为现代硬件上的经典机器学习算子打造的 GPU 库,专为当今的 ML 工作负载和新兴的智能体 AI 系统而重建。以下是第一个版本的一些亮点成果:

  • 在 Hopper GPU 上大幅超越 cuML:KMeans 最高 26 倍,KNN 19 倍,HDBSCAN 40 倍,TruncatedSVD 208 倍,PCA 47 倍,精确 t-SNE 147 倍,MultinomialNB 49 倍
  • 信息丰富的 Flash API:在约 5 微秒的纯 CPU 时间内预测任意工作负载的运行时间、内存占用和开销,无需任何 GPU 性能分析。
  • 快速冷启动,可扩展:FlashLib 使用启发式内核选择,避免长时间自动调优循环,并已支持多 GPU 执行大型工作负载。
  • 接近最优的硬件利用率:FlashLib 驱动内核更接近现代 GPU 的极限,其中 Flash-KMeans 在 H200 上达到 峰值 FLOPs 的 61%,Flash-KNN 达到 峰值 HBM 带宽的 85.2%

AI 效率的下一个前沿不仅仅是更快的 LLM 推理。而是更快的智能组装。在过去几年,MLsys 的工作主要遵循以模型为中心的智能观。随着 LLM 通过更好的推理、更大规模的测试时计算和更强的推理能力变得更强,系统社区专注于加速 Transformer 核心:FlashAttention、FlashDecoding、KV 缓存管理和 LLM 服务系统等。

但智能体 AI 的兴起正在改变瓶颈。现代智能日益围绕基础模型,通过工具、工具集、检索、验证、搜索和编排构建。LLM 不再是独立的推理者;它成为更广泛计算系统的控制器。因此,性能瓶颈不再局限于 Transformer 推理。它扩展到模型周围的整个计算基础。例如,在科学智能体 AI 中,LLM 智能体可能生成假设或候选解,但周围的循环通常依赖搜索、聚类、近邻检索、PCA、SVD 和其他经典机器学习算子进行验证和反馈。在多模态生成和物理 AI 中,模型在进入模型之前必须在线处理、压缩、检索和重组流式特征。这些例子指向一个更广泛的转变:经典机器学习算子正在成为围绕 LLM 模型的核心原语。我们设想未来的智能体工作流中,聚类、检索、降维、验证和线性代数不再是离线工具,而是成为智能组装关键路径上的在线原语。图 1 说明了这一转变。

五种经典机器学习算子从批次延迟层迁移到毫秒服务层,跨越十年,并标有细化标签。与之前的延迟图相同,有两个标签细化:视频生成现在是流式视频生成,PCA 基础的 KV 压缩缩写为 PCA 基础的压缩。K-means, k-NN, TruncatedSVD, PCA, HDBSCAN, 1 ms, 10 ms, 100 ms, 1 s, 1 min, 1 hr, 2015, 2018, 2021, 2024, 2027。年份算子进入此延迟层。用户分群,特征降维,话题建模,文档聚类;项目-项目推荐系统,管线 PCA,嵌入压缩,话题发现;RAG 检索,语义缓存,PCA 基础压缩,SVD 基础压缩,流式视频生成,KV 缓存聚类,智能体工具路由。 图 1:过去十年间,经典 ML 算子 (KMeans, k-NN, TruncatedSVD, PCA, HDBSCAN) 的延迟预算在对数尺度上持续下降。曾经在分钟到小时层级离线运行的相同原语 (用户分群、话题建模、批次特征降维) 现在正被调用到在线服务路径 (RAG 检索、语义缓存、KV 缓存聚类、智能体工具路由) 中,其预算以毫秒为单位。随着这一趋势继续,系统社区需要这些算子的实现,这些实现不仅快速、硬件高效、可靠,而且在数值上足够忠实,能够胜任关键路径。*悬停 (或点击) 任意点以查看它所代表的具体工作。*然而,这些经典算子的底层实现并未跟上这一转变。它们的核心设计假设仍然来自 FlashAttention 之前、Hopper 之前、智能体之前的时代,这造成了四个方面的不匹配。首先,许多算子具有天然不适用于 GPU 的实现。其次,许多库针对所有工作负载和硬件层级提供单一静态内核实现,导致现代 GPU 硬件特性未被利用。第三,许多库不了解用户的精度需求:它们没有提供声明精度预算的方法,导致用户无法要求满足其容差的最快算法。第四,性能是黑盒:性能分析成本高昂,难以修改,如果不先阅读代码库就无法进行预算,这让开发者和基于 LLM 的智能体都处于黑暗之中。

FlashLib 是我们尝试弥合这些差距并加速这一新兴基础的努力,使其足够快以嵌入智能体 AI 的循环中。它将经典 ML 算子从慢速的离线工具转变为快速的在线 ML 原语。此外,FlashLib 暴露了 flash-informative API,向更高级别的智能体管线揭示这些原语的成本、容差和执行行为,从而实现更好的调度和编排。我们还想指出,虽然 FlashLib 是受 LLM 中心和智能体 AI 系统新兴需求驱动,但我们也认识到经典 ML 算法在当今的机器学习栈中仍然广泛使用。除了生成式 AI 之外,KMeans、KNN、PCA、SVD、t-SNE 和 HDBSCAN 等算子仍然是推荐系统、检索管线、科学计算、异常检测、可视化和下游 ML 模型预处理的基石模块。FlashLib 提供了一个快速、易用且自适应的软件栈,通过即插即用的 GPU 加速覆盖这些多样化的应用。

我们围绕四个设计原则构建 FlashLib。首先,我们在实现数学等价性的同时,重塑算法以适应硬件。其次,我们为每个算子构建内核变体,利用现代硬件特性在不同硬件上充分探索不同工作负载。第三,我们让用户声明一个精度预算,并路由到满足该预算的最快算法。第四,我们保持整个库足够透明,使用户和 LLM 智能体能够轻松阅读、组合和修改内核。

01 / 04 重新表述

数学等价的重新表述:将算子重写为对 GPU 友好

许多经典 ML 算子具有天然不适用于 GPU 的实现:它们在 HBM 中物化大型中间结果,引入原子竞争,或在分块不佳的维度上运行规约。FlashLib 的第一个原则是将它们重写为数学上等价但对现代加速器友好的形式。KMeans 分配是最清晰的例子:自然实现在 HBM 中形成一个 N × K 距离矩阵,并对每行运行 argmin,但流式融合版本将运行中的局部最小值保存在寄存器中,从不物化矩阵。同样的模式在整个库中重复出现:KNN 的融合 top-K 跳过 ‖xy‖2 = ‖x‖2 + ‖y*‖2 − 2⟨x,y⟩ 中的 ‖x‖2 项,PCA 的双 Gram 路径选择 XX (D×D) 和 X X⊤ (N×N) 中较小的一个,避免浪费 O(max(N,D)3) 的 eigh,而 cuML 的固定 D×D 路径在宽数据上运行的是这个;MultinomialNB 将原子散列更改为段级规约,t-SNE 的梯度从不物化 N×N Q 矩阵。

02 / 04 硬件感知内核

硬件感知实现:针对不同硬件和工作负载的内核变体

为了将这些数学公式直接映射到芯片,FlashLib 构建了多个内核变体,以适应硬件和工作负载。Flash-KNN 说明了这种方法。首先,在后端层,我们提供了一个可移植的 Triton 实现,适用于 Ampere 和 Hopper。对于 Hopper,一个可选的 CuteDSL FA3 后端额外解锁了现代特性,如 TMA 获取和 warp 专用流水线。其次,在内核层,设计适应工作负载。对于大查询,内核镜像标准 FlashAttention 以最大化 TensorCore 利用率。对于针对巨大语料库的小查询,内核镜像 Flash-Decoding:一个 split-k 布局沿语料库维度分布工作,防止 SM 空闲。第三,在启发式层,我们根据硬件特性 (如缓存大小和寄存器容量) 选择超级参数,如 tile 大小和 warp 计数。结果是,即使是一个 Q=1 对 1 亿向量语料库的查询,内核也能在 H200 上保持 峰值 HBM 带宽的 85.2%

03 / 04 容差路由

容差驱动分发:在精度预算内路由到最快算法

FlashLib 还将速度-精度权衡作为用户选择暴露出来。经典科学计算通常要求 FP32 甚至 FP64 的高精度——用于求解 PDE、认证数值方法,或任何微小误差会级联成错误答案的地方。许多 AI 工作负载没有这样的要求:对嵌入向量进行聚类 pass、top-K 检索或对噪声数据回归,可以吸收一个小的声明残差以换取显著加速。FlashLib 将此区别交由用户通过每个调用的参数 tol 来划定。在 tol=None 时,规约保持精确精度,调用从上述的内核融合优势中获益。在 tol > 0 时,分发器通过精度模拟 (融合变体如 3xbf16 和 Ozaki-II INT8) 和算法替换 (Halko 子空间迭代) 的 Pareto 前沿路由,选择在声明残差内吞吐量最高的那个。

# GEMM:相同的调用,不同的 tol -> 不同的变体。
flashlib.gemm(A, B)               # 精确 fp32
flashlib.gemm(A, B, tol=1e-3)     # bf16
flashlib.gemm(A, B, tol=1e-5)     # 3xbf16 (cute-fused)
flashlib.gemm(A, B, tol=1e-7)     # ozaki2_cute(s=8):更紧且更快
flashlib.gemm(A, B, tol=1e-12)    # ozaki2_int8(s=14):FP64 级别

# PCA:tol 释放算法替换,而不仅仅是精度。
flashlib.flash_pca(X, K=32)            # 精确 eigh 在 Gram / cov 矩阵上
flashlib.flash_pca(X, K=32, tol=1e-4)  # Halko 子空间:约 30 倍更快

04 / 04 成本可预测 API

智能体原生 API:对用户和智能体透明的源码和可预测成本

在基于 LLM 的智能体越来越频繁地读取、调用和修改性能代码的时代,库的成本不仅仅是其内核吞吐量,还有其成本模型和源码的可读性。FlashLib 使用 Triton 和 CuteDSL 编写,没有不透明的二进制文件——从 flash_kmeans(...)tl.dot 调用的每个内核都是可编辑的。并且每个原语都附带一个无需 GPU 的成本预测接口:flashlib.info.estimate(...) 接受一个形状和容差,返回一个运行时间、FLOPs、HBM 字节和绑定状况的递归树,在纯 CPU 上大约 5 微秒内完成,绝不导入 torch、triton 或 cutlass。LLM 智能体可以组合一个包含十个原语的管线,遍历该成本树,并在花费一个 FLOP 之前决定预算是否合适。

import flashlib.info as info   # 纯 stdlib -- 无需 torch/triton/cutlass。

# 在不接触 GPU 的情况下预测成本 -- 纯 CPU 上约 5 微秒。
est = info.estimate("pca", shape=(1_000_000, 512), params={"K": 32}, device="H200")

print(est.summary_line())
# pca   13.18 ms  bound=compute   42 TF  (83% peak)  res~1e-7  [roofline]

est.print_tree()              # 遍历递归调用栈树
# pca           13.18 ms  2.18 GB  compute   42 TF  res~1e-7
# ├── cov_gemm  10.49 ms  2.05 GB  compute   50 TF
# ├── eigh       0.12 ms  0.00 GB  compute    3 TF
# └── transform  2.57 ms  2.18 GB  compute   13 TF

# 在 H200 上 4Kx4Kx4K 矩阵乘法的 Pareto 最优 GEMM 变体:
for v in info.pareto("gemm", shape=(4096, 4096, 4096), device="H200"):
    print(v)
# Variant('gemm_fp16'           : 0.2 ms  residual~8e-04)
# Variant('gemm_tf32'           : 0.4 ms  residual~3e-04)
# Variant('gemm_3xfp16'         : 0.5 ms  residual~2e-04)
# Variant('gemm_fp16_x3_kahan'  : 0.6 ms  residual~5e-07)
# Variant('gemm_ozaki2_cute'    : 0.8 ms  residual~2e-15)

基准测试

以下所有结果均在单个 NVIDIA H200 (SM90, 150 GB HBM3e) 上测量,使用 CUDA 13.0、驱动580.126PyTorch 2.11Triton 3.6,并与 cuML 25.10 比较。每个单元格是 5 次迭代的中位数,第一次调用丢弃以摊销 JIT 成本;输入在两侧均驻留在 GPU 上;匹配算法行 (相同的 algorithmmethodsvd_solver) 与降低精度和算法捷径行配对,以便在每个形状上进行公平比较。

1. 广度:在 13 个原语上相比 cuML 的加速比

第一个基准测试是大范围扫描:13 个原语 × 194 个 (shape, dtype, hyperparameter) 单元格,全部在相同的 H200 上与 cuml 25.10 运行。**这里每个单元格都是同类比较:匹配的算法、匹配的精度、匹配的超参数——FlashLib 无需使用任何降低精度的 GEMM (无 bf16/fp16/Ozaki) 或算法捷径 (无 Halko、无 FFT t-SNE、无 NN-Descent KNN)。**因此,下面的柱状图严格低于本文顶部的标题数字:英雄统计数据是 FlashLib 在每个原语上的最佳上限加速比 (如果适用,确实允许用户通过原则 03 的 tol 按钮权衡精度或算法精确性与吞吐量),而广泛扫描故意移除这些自由度以孤立纯内核工程的优势。下面的柱状图将每个原语的 8–34 个单元格压缩到一根柱状条——Geomean 显示该原语所有单元格的几何平均加速比,Max 显示在最优单元格上的每原语上限。FlashLib 在 193 / 194 个单元格上至少与 cuML 一样快;126 个单元格超过 5 倍,11 个超过 50 倍。

亮点点击五个原语按钮之一将其固定到顶部,然后在 geomean 和 max 之间切换,观察排序和柱状条长度平滑变化。

2. 深度:一个原语 (GEMM) 内部精度 × 运行时间权衡

第二个基准测试深入单个原语,即 H200 上的 40963 GEMM,以展示容差路由在实践中的样子。FlashLib 在 flashlib.linalg.gemm 中提供约 10 个 GEMM 变体——bf16、fp16、tf32、融合多通道 (3xbf16fp16_x3_kahan) 以及 Ozaki-II INT8 系列 (ozaki2_cuteozaki2_int8)——全部位于原则 03 中单个 tol 路由分发器之后。下面的散点图绘制了每个变体的 RMS 相对误差 (相对于 FP64 参考) 与每次调用运行时间。虚线是 Pareto 前沿:它下面的变体优于其余变体。有趣的点是 ozaki2_cute(s=8):它位于 fp32 的下方 左侧,意味着它在相同误差下更快。ozaki2_int8(s=14) 甚至更左,接近 FP64 误差但速度快得多。这就是容差路由的实际运作方式:用户设置 tol,FlashLib 自动选择前沿上满足约束的变体。

相似文章

Flash-GMM:一种用于可扩展软聚类的内存高效内核

Hugging Face Daily Papers

Flash-GMM 引入了一个用于高斯混合模型的融合Triton内核,实现了20倍加速,并能在单个GPU上训练比之前大100倍的数据集,使软聚类成为近似最近邻搜索中k-means的可行替代方案。