Is anyone running Flash Attention 2 (ai-bond) on their V100? How does it perform?
Summary
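A user benchmarked a V100-compatible port of Flash Attention 2 and reported end-to-end (forward plus backward) speedups of roughly 2.6x to 17x over PyTorch's default attention, with up to 94% less memory.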
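Memory savings of this size are consistent with the core FlashAttention idea: instead of materializing the full M×N attention-score matrix, the kernel streams over key/value blocks and keeps a running ("online") softmax. Below is a minimal pure-Python sketch of that trick for a single query vector, plus a helper that computes how large one materialized fp16 score tensor would be. The function names and block size are illustrative assumptions; this is not code from the ai-bond repository, and it ignores causal masking, batching, and the backward pass.

```python
import math
import random


def score_matrix_bytes(b, h, m, n, bytes_per_elem=2):
    """Bytes needed to materialize one B x H x M x N attention-score tensor (fp16 by default)."""
    return b * h * m * n * bytes_per_elem


def naive_attention(q, keys, values):
    """Reference: compute every score for one query, softmax them, take the weighted sum of values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    mx = max(scores)
    weights = [math.exp(s - mx) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]


def tiled_attention(q, keys, values, block=4):
    """Online-softmax attention: visit keys/values one block at a time and keep only a
    running max, a running softmax denominator, and an unnormalized output accumulator."""
    scale = 1.0 / math.sqrt(len(q))
    d_v = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * d_v
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far from the old max to the new max
        # (on the first block this factor is exp(-inf) == 0.0).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(d_v):
                acc[j] += w * v[j]
        running_max = new_max
    # Normalize once at the end.
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    random.seed(0)
    d, n = 8, 37  # n deliberately not a multiple of the block size
    q = [random.gauss(0, 1) for _ in range(d)]
    keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    ref = naive_attention(q, keys, values)
    out = tiled_attention(q, keys, values, block=4)
    print(max(abs(a - b) for a, b in zip(ref, out)))  # tiny floating-point difference
    print(score_matrix_bytes(1, 32, 1024, 1024) / 2**20)  # 64.0 (MiB)
```

For B=1, H=32, M=N=1024, a single fp16 score tensor is already 64 MiB, and a naive implementation typically keeps several tensors of that shape alive (scores, softmax probabilities, and their gradients during backward). That quadratic term is one plausible contributor to the large PyTorch memory figures at H=16/32, M=N=1024, while the tiled version only ever holds one block of scores at a time.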
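I just installed Flash Attention 2 from https://github.com/ai-bond/flash-attention-v100 and ran some basic benchmarks: memory utilization improved by 4x to 7x. That said, benchmarks don't always reflect real-world use. **What I did notice is that the thinking time before the model answers has been minimized.** Here are some of my results: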
All tests: ✅ Forward match OK, ✅ Backward match OK.

Performance (memory in MB, time in ms, shown as Custom / PyTorch):

| B | H | M | N | D | causal | Memory (Δ) | fwd (speedup) | bwd (speedup) | total (speedup) |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 128 | 128 | 128 | True | 17.1 / 17.6 (-0.5 MB, -3.1%) | 0.09 / 0.90 (9.63x) | 0.10 / 2.48 (24.31x) | 0.20 / 3.38 (17.28x) |
| 1 | 1 | 256 | 256 | 256 | False | 19.3 / 21.4 (-2.1 MB, -9.9%) | 0.10 / 0.67 (7.06x) | 0.12 / 2.18 (18.49x) | 0.21 / 2.85 (13.38x) |
| 1 | 1 | 256 | 256 | 256 | True | 19.6 / 21.8 (-2.2 MB, -10.0%) | 0.09 / 0.90 (9.57x) | 0.12 / 2.29 (19.64x) | 0.21 / 3.19 (15.14x) |
| 1 | 16 | 1024 | 1024 | 16 | False | 28.5 / 351.9 (-323.4 MB, -91.9%) | 0.28 / 0.94 (3.36x) | 0.70 / 2.46 (3.53x) | 0.98 / 3.40 (3.48x) |
| 1 | 16 | 1024 | 1024 | 16 | True | 30.0 / 354.4 (-324.4 MB, -91.5%) | 0.20 / 1.30 (6.38x) | 0.41 / 3.06 (7.42x) | 0.62 / 4.36 (7.07x) |
| 1 | 32 | 1024 | 1024 | 16 | False | 41.8 / 688.5 (-646.8 MB, -93.9%) | 0.45 / 1.35 (3.03x) | 1.15 / 3.77 (3.29x) | 1.59 / 5.12 (3.21x) |
| 1 | 32 | 1024 | 1024 | 16 | True | 43.8 / 691.5 (-647.8 MB, -93.7%) | 0.35 / 2.01 (5.72x) | 0.76 / 5.09 (6.72x) | 1.11 / 7.10 (6.40x) |
| 1 | 16 | 1024 | 1024 | 32 | False | 43.5 / 370.4 (-326.9 MB, -88.3%) | 0.25 / 0.93 (3.74x) | 0.69 / 2.37 (3.43x) | 0.94 / 3.30 (3.51x) |
| 1 | 16 | 1024 | 1024 | 32 | True | 43.5 / 371.4 (-327.9 MB, -88.3%) | 0.18 / 1.26 (7.09x) | 0.45 / 3.00 (6.61x) | 0.63 / 4.26 (6.75x) |
| 1 | 32 | 1024 | 1024 | 32 | False | 66.8 / 720.5 (-653.8 MB, -90.7%) | 0.46 / 1.44 (3.16x) | 1.38 / 3.93 (2.85x) | 1.84 / 5.37 (2.93x) |
| 1 | 32 | 1024 | 1024 | 32 | True | 70.8 / 725.5 (-654.8 MB, -90.2%) | 0.30 / 2.07 (6.89x) | 0.82 / 5.27 (6.46x) | 1.12 / 7.34 (6.58x) |
| 1 | 16 | 1024 | 1024 | 64 | False | 70.5 / 404.4 (-333.9 MB, -82.6%) | 0.34 / 1.02 (2.97x) | 1.00 / 2.63 (2.63x) | 1.34 / 3.65 (2.72x) |
| 1 | 16 | 1024 | 1024 | 64 | True | 70.5 / 405.4 (-334.9 MB, -82.6%) | 0.24 / 1.38 (5.73x) | 0.69 / 3.27 (4.76x) | 0.93 / 4.65 (5.01x) |
| 1 | 32 | 1024 | 1024 | 64 | False | 116.8 / 784.5 (-667.8 MB, -85.1%) | 0.57 / 1.74 (3.04x) | 1.94 / 4.81 (2.48x) | 2.51 / 6.54 (2.61x) |

Validation (max error ≤ tolerance, as printed by the benchmark):

| B | H | M | N | D | causal | Fwd: dO err | Bwd: dQ err | Bwd: dK err | Bwd: dV err |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 128 | 128 | 128 | True | 9.77e-04 ≤ 2×9.77e-04 | 9.77e-04 ≤ 3×1.95e-03 | 9.77e-04 ≤ 3×1.95e-03 | 9.77e-04 ≤ 3×1.95e-03 |
| 1 | 1 | 256 | 256 | 256 | False | 2.44e-04 ≤ 2×7.32e-04 | 2.44e-04 ≤ 3×4.88e-04 | 4.88e-04 ≤ 3×4.88e-04 | 4.88e-04 ≤ 3×9.77e-04 |
| 1 | 1 | 256 | 256 | 256 | True | 9.77e-04 ≤ 2×1.95e-03 | 9.77e-04 ≤ 3×9.77e-04 | 9.77e-04 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 16 | 1024 | 1024 | 16 | False | 2.44e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×7.32e-04 |
| 1 | 16 | 1024 | 1024 | 16 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 32 | 1024 | 1024 | 16 | False | 2.44e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×9.77e-04 | 4.88e-04 ≤ 3×7.32e-04 |
| 1 | 32 | 1024 | 1024 | 16 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 16 | 1024 | 1024 | 32 | False | 2.44e-04 ≤ 2×7.32e-04 | 2.44e-04 ≤ 3×1.22e-03 | 2.44e-04 ≤ 3×1.22e-03 | 2.44e-04 ≤ 3×9.77e-04 |
| 1 | 16 | 1024 | 1024 | 32 | True | 9.77e-04 ≤ 2×1.95e-03 | 9.77e-04 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×3.91e-03 |
| 1 | 32 | 1024 | 1024 | 32 | False | 2.44e-04 ≤ 2×1.22e-03 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×1.10e-03 |
| 1 | 32 | 1024 | 1024 | 32 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.46e-03 ≤ 3×2.93e-03 | 1.95e-03 ≤ 3×2.93e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 16 | 1024 | 1024 | 64 | False | 1.22e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×7.32e-04 | 4.88e-04 ≤ 3×7.32e-04 | 2.44e-04 ≤ 3×4.88e-04 |
| 1 | 16 | 1024 | 1024 | 64 | True | 9.77e-04 ≤ 2×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 | 1.95e-03 ≤ 3×1.95e-03 |
| 1 | 32 | 1024 | 1024 | 64 | False | 2.44e-04 ≤ 2×4.88e-04 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×1.46e-03 | 4.88e-04 ≤ 3×9.77e-04 |

Test: B=1, H=3
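The log's metrics appear to be simple ratios: speedup is PyTorch time divided by custom time, and the memory Δ percentage is the difference relative to the PyTorch figure (for example, -646.8 / 688.5 ≈ -93.9%). The `Validation` lines resemble the convention in FlashAttention's own test suite, where the custom kernel's maximum error against a higher-precision reference must not exceed a fixed multiple (here 2× for forward, 3× for backward) of PyTorch's error at the same low precision. The sketch below encodes that reading; it is an assumption about what the benchmark script does, and every function name is hypothetical.

```python
def speedup(baseline_ms, custom_ms):
    """How many times faster the custom kernel is than the baseline."""
    return baseline_ms / custom_ms


def mem_delta(custom_mb, baseline_mb):
    """Absolute memory difference in MB, and that difference as a percentage of the baseline."""
    delta = custom_mb - baseline_mb
    return delta, 100.0 * delta / baseline_mb


def max_abs_err(xs, ref):
    """Largest elementwise absolute difference between two equally long sequences."""
    return max(abs(x - r) for x, r in zip(xs, ref))


def within_tolerance(custom, baseline, reference, factor):
    """Pass if the custom kernel's max error against a higher-precision reference is at most
    `factor` times the baseline's (e.g. PyTorch at the same low precision) max error."""
    return max_abs_err(custom, reference) <= factor * max_abs_err(baseline, reference)


if __name__ == "__main__":
    # Row B=1, H=32, M=N=1024, D=16, causal=False from the benchmark results.
    d, pct = mem_delta(41.8, 688.5)
    print(round(d, 1), round(pct, 1))  # -646.7 -93.9 (log shows -646.8, presumably from unrounded MB)
    print(round(speedup(5.12, 1.59), 2))  # 3.22 (log shows 3.21x, presumably from unrounded timings)
    # Toy tolerance check using the forward factor of 2.
    reference = [1.0, 2.0, 3.0]
    baseline = [1.001, 2.0, 3.0]  # max error ~0.001
    custom = [1.0, 2.0015, 3.0]   # max error ~0.0015
    print(within_tolerance(custom, baseline, reference, factor=2))  # True
```

Recomputing from the rounded numbers in the log can differ in the last printed digit, so small mismatches like 3.22x versus 3.21x are expected rather than a sign of a reporting error.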
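Similar articles
@rohanpaul_ai: Pretty remarkable: MiniMax Sparse Attention cuts attention compute by 28.4x at 1M tokens, speeds up prefill by 14.2x, and…
MiniMax Sparse Attention (MSA) adds a routing branch that selects which key-value blocks take part in attention, achieving up to a 28.4x reduction in attention compute at 1M tokens and 14.2x faster prefill and 7.6x faster decoding on H800 GPUs, while matching the performance of a full-attention baseline.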
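The routing idea in this summary (pick a few key/value blocks per query and attend only to those) can be illustrated with a generic block-sparse top-k selection. This sketch is not MiniMax's method: the router here simply scores each block by the query's dot product with the block's mean key, an assumption chosen for illustration, and all function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def mean_pool(vectors):
    """Elementwise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def select_blocks(q, keys, block, top_k):
    """Hypothetical router: score each key block by q . mean(block keys) and
    return the indices (ascending) of the top_k highest-scoring blocks."""
    n_blocks = math.ceil(len(keys) / block)
    scored = [(dot(q, mean_pool(keys[b * block:(b + 1) * block])), b) for b in range(n_blocks)]
    scored.sort(reverse=True)
    return sorted(b for _, b in scored[:top_k])


def block_sparse_attention(q, keys, values, block, top_k):
    """Softmax attention computed only over key/value positions in the selected blocks."""
    chosen = select_blocks(q, keys, block, top_k)
    idx = [i for b in chosen for i in range(b * block, min((b + 1) * block, len(keys)))]
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * dot(q, keys[i]) for i in idx]
    mx = max(scores)
    weights = [math.exp(s - mx) for s in scores]
    total = sum(weights)
    out = [sum(w * values[i][j] for w, i in zip(weights, idx)) / total
           for j in range(len(values[0]))]
    return out, idx


if __name__ == "__main__":
    q = [1.0, 0.0]
    keys = [[1.0, 0.0], [1.0, 0.0],    # block 0: aligned with q
            [-1.0, 0.0], [-1.0, 0.0],  # block 1: opposite to q
            [0.5, 0.0], [0.5, 0.0]]    # block 2: partly aligned
    values = [[float(i)] for i in range(len(keys))]
    print(select_blocks(q, keys, block=2, top_k=1))  # [0]
    print(select_blocks(q, keys, block=2, top_k=2))  # [0, 2]
    out, idx = block_sparse_attention(q, keys, values, block=2, top_k=2)
    print(idx)  # [0, 1, 4, 5]: only 4 of the 6 keys are scored
```

Selecting every block reproduces dense attention, so `top_k` directly controls what fraction of keys get scored; the compute savings quoted for long contexts come from keeping that fraction small while the router still finds the blocks that matter.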
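RDNA2 Flash Attention isn't enabled in the official builds; I enabled it with this build and doubled my speed
A custom binary build enables Flash Attention in llama.cpp on AMD RDNA2 GPUs, doubling inference speed (70-80 tok/s, whereas the official build crashes). Only confirmed to work with Qwen3.6 35B/27B.
@levidiamode: Day 158/365 of GPU programming: I think I roughly understand the high-level differences between the FlashAttention 2, 3, and 4 forward passes…
The author documents their progress learning GPU programming, focusing on the high-level differences between the FlashAttention 2, 3, and 4 forward passes, and lists several low-level concepts to explore further.
@rohanpaul_ai: New Alibaba and Nanjing University paper claims up to 9.36x faster million-token prefill (vs. FlashAttention-2)…
A new paper from Alibaba and Nanjing University introduces RTPurbo, a method that speeds up million-token prefill by up to 9.36x over FlashAttention-2 by applying full attention selectively, only where it is needed, without retraining the model.
Built an AI accelerator and open-sourced it. [P]
The author open-sourced a custom AI accelerator (atik) implemented on an FPGA, with native BF16 and attention support, and reports significant speedups over PyTorch across a range of models.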