Is anyone using Flash Attention 2 (ai-bond) on their V100? How is the performance?

Reddit r/LocalLLaMA Tools

Summary

A user benchmarked a V100-compatible port of Flash Attention 2 and reported a speedup of roughly 3x to 17x over default PyTorch attention, with memory use reduced by up to 94%.

I just installed Flash Attention 2 from here: https://github.com/ai-bond/flash-attention-v100

I ran some basic benchmarks and saw a 4x to 7x improvement in memory utilization. However, benchmarks do not always reflect real-world use. **I did notice that the thinking time before a response has been minimized.** Here are some of my results:

Test: B=1, H=1, M=128, N=128, D=128, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 17.1 MB, PyTorch: 17.6 MB (Δ: -0.5 MB, -3.1%)
(fwd): Custom: 0.09ms, PyTorch: 0.90ms (9.63x speedup)
(bwd): Custom: 0.10ms, PyTorch: 2.48ms (24.31x speedup)
(tot): Custom: 0.20ms, PyTorch: 3.38ms (17.28x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×9.77e-04
(Bwd): dQ err=9.77e-04 ≤ 3×1.95e-03
dK err=9.77e-04 ≤ 3×1.95e-03
dV err=9.77e-04 ≤ 3×1.95e-03
======================================================================
Test: B=1, H=1, M=256, N=256, D=256, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 19.3 MB, PyTorch: 21.4 MB (Δ: -2.1 MB, -9.9%)
(fwd): Custom: 0.10ms, PyTorch: 0.67ms (7.06x speedup)
(bwd): Custom: 0.12ms, PyTorch: 2.18ms (18.49x speedup)
(tot): Custom: 0.21ms, PyTorch: 2.85ms (13.38x speedup)
Validation:
(Fwd): dO err=2.44e-04 ≤ 2×7.32e-04
(Bwd): dQ err=2.44e-04 ≤ 3×4.88e-04
dK err=4.88e-04 ≤ 3×4.88e-04
dV err=4.88e-04 ≤ 3×9.77e-04
======================================================================
Test: B=1, H=1, M=256, N=256, D=256, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 19.6 MB, PyTorch: 21.8 MB (Δ: -2.2 MB, -10.0%)
(fwd): Custom: 0.09ms, PyTorch: 0.90ms (9.57x speedup)
(bwd): Custom: 0.12ms, PyTorch: 2.29ms (19.64x speedup)
(tot): Custom: 0.21ms, PyTorch: 3.19ms (15.14x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×1.95e-03
(Bwd): dQ err=9.77e-04 ≤ 3×9.77e-04
dK err=9.77e-04 ≤ 3×1.95e-03
dV err=1.95e-03 ≤ 3×1.95e-03
======================================================================
Test: B=1, H=16, M=1024, N=1024, D=16, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 28.5 MB, PyTorch: 351.9 MB (Δ: -323.4 MB, -91.9%)
(fwd): Custom: 0.28ms, PyTorch: 0.94ms (3.36x speedup)
(bwd): Custom: 0.70ms, PyTorch: 2.46ms (3.53x speedup)
(tot): Custom: 0.98ms, PyTorch: 3.40ms (3.48x speedup)
Validation:
(Fwd): dO err=2.44e-04 ≤ 2×4.88e-04
(Bwd): dQ err=4.88e-04 ≤ 3×9.77e-04
dK err=4.88e-04 ≤ 3×9.77e-04
dV err=4.88e-04 ≤ 3×7.32e-04
======================================================================
Test: B=1, H=16, M=1024, N=1024, D=16, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 30.0 MB, PyTorch: 354.4 MB (Δ: -324.4 MB, -91.5%)
(fwd): Custom: 0.20ms, PyTorch: 1.30ms (6.38x speedup)
(bwd): Custom: 0.41ms, PyTorch: 3.06ms (7.42x speedup)
(tot): Custom: 0.62ms, PyTorch: 4.36ms (7.07x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×1.95e-03
(Bwd): dQ err=1.95e-03 ≤ 3×3.91e-03
dK err=1.95e-03 ≤ 3×1.95e-03
dV err=1.95e-03 ≤ 3×1.95e-03
======================================================================
Test: B=1, H=32, M=1024, N=1024, D=16, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 41.8 MB, PyTorch: 688.5 MB (Δ: -646.8 MB, -93.9%)
(fwd): Custom: 0.45ms, PyTorch: 1.35ms (3.03x speedup)
(bwd): Custom: 1.15ms, PyTorch: 3.77ms (3.29x speedup)
(tot): Custom: 1.59ms, PyTorch: 5.12ms (3.21x speedup)
Validation:
(Fwd): dO err=2.44e-04 ≤ 2×4.88e-04
(Bwd): dQ err=4.88e-04 ≤ 3×9.77e-04
dK err=4.88e-04 ≤ 3×9.77e-04
dV err=4.88e-04 ≤ 3×7.32e-04
======================================================================
Test: B=1, H=32, M=1024, N=1024, D=16, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 43.8 MB, PyTorch: 691.5 MB (Δ: -647.8 MB, -93.7%)
(fwd): Custom: 0.35ms, PyTorch: 2.01ms (5.72x speedup)
(bwd): Custom: 0.76ms, PyTorch: 5.09ms (6.72x speedup)
(tot): Custom: 1.11ms, PyTorch: 7.10ms (6.40x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×1.95e-03
(Bwd): dQ err=1.95e-03 ≤ 3×3.91e-03
dK err=1.95e-03 ≤ 3×1.95e-03
dV err=1.95e-03 ≤ 3×1.95e-03
======================================================================
Test: B=1, H=16, M=1024, N=1024, D=32, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 43.5 MB, PyTorch: 370.4 MB (Δ: -326.9 MB, -88.3%)
(fwd): Custom: 0.25ms, PyTorch: 0.93ms (3.74x speedup)
(bwd): Custom: 0.69ms, PyTorch: 2.37ms (3.43x speedup)
(tot): Custom: 0.94ms, PyTorch: 3.30ms (3.51x speedup)
Validation:
(Fwd): dO err=2.44e-04 ≤ 2×7.32e-04
(Bwd): dQ err=2.44e-04 ≤ 3×1.22e-03
dK err=2.44e-04 ≤ 3×1.22e-03
dV err=2.44e-04 ≤ 3×9.77e-04
======================================================================
Test: B=1, H=16, M=1024, N=1024, D=32, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 43.5 MB, PyTorch: 371.4 MB (Δ: -327.9 MB, -88.3%)
(fwd): Custom: 0.18ms, PyTorch: 1.26ms (7.09x speedup)
(bwd): Custom: 0.45ms, PyTorch: 3.00ms (6.61x speedup)
(tot): Custom: 0.63ms, PyTorch: 4.26ms (6.75x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×1.95e-03
(Bwd): dQ err=9.77e-04 ≤ 3×1.95e-03
dK err=1.95e-03 ≤ 3×1.95e-03
dV err=1.95e-03 ≤ 3×3.91e-03
======================================================================
Test: B=1, H=32, M=1024, N=1024, D=32, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 66.8 MB, PyTorch: 720.5 MB (Δ: -653.8 MB, -90.7%)
(fwd): Custom: 0.46ms, PyTorch: 1.44ms (3.16x speedup)
(bwd): Custom: 1.38ms, PyTorch: 3.93ms (2.85x speedup)
(tot): Custom: 1.84ms, PyTorch: 5.37ms (2.93x speedup)
Validation:
(Fwd): dO err=2.44e-04 ≤ 2×1.22e-03
(Bwd): dQ err=4.88e-04 ≤ 3×1.46e-03
dK err=4.88e-04 ≤ 3×1.46e-03
dV err=4.88e-04 ≤ 3×1.10e-03
======================================================================
Test: B=1, H=32, M=1024, N=1024, D=32, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 70.8 MB, PyTorch: 725.5 MB (Δ: -654.8 MB, -90.2%)
(fwd): Custom: 0.30ms, PyTorch: 2.07ms (6.89x speedup)
(bwd): Custom: 0.82ms, PyTorch: 5.27ms (6.46x speedup)
(tot): Custom: 1.12ms, PyTorch: 7.34ms (6.58x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×1.95e-03
(Bwd): dQ err=1.46e-03 ≤ 3×2.93e-03
dK err=1.95e-03 ≤ 3×2.93e-03
dV err=1.95e-03 ≤ 3×1.95e-03
======================================================================
Test: B=1, H=16, M=1024, N=1024, D=64, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 70.5 MB, PyTorch: 404.4 MB (Δ: -333.9 MB, -82.6%)
(fwd): Custom: 0.34ms, PyTorch: 1.02ms (2.97x speedup)
(bwd): Custom: 1.00ms, PyTorch: 2.63ms (2.63x speedup)
(tot): Custom: 1.34ms, PyTorch: 3.65ms (2.72x speedup)
Validation:
(Fwd): dO err=1.22e-04 ≤ 2×4.88e-04
(Bwd): dQ err=4.88e-04 ≤ 3×7.32e-04
dK err=4.88e-04 ≤ 3×7.32e-04
dV err=2.44e-04 ≤ 3×4.88e-04
======================================================================
Test: B=1, H=16, M=1024, N=1024, D=64, causal=True
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 70.5 MB, PyTorch: 405.4 MB (Δ: -334.9 MB, -82.6%)
(fwd): Custom: 0.24ms, PyTorch: 1.38ms (5.73x speedup)
(bwd): Custom: 0.69ms, PyTorch: 3.27ms (4.76x speedup)
(tot): Custom: 0.93ms, PyTorch: 4.65ms (5.01x speedup)
Validation:
(Fwd): dO err=9.77e-04 ≤ 2×1.95e-03
(Bwd): dQ err=1.95e-03 ≤ 3×1.95e-03
dK err=1.95e-03 ≤ 3×1.95e-03
dV err=1.95e-03 ≤ 3×1.95e-03
======================================================================
Test: B=1, H=32, M=1024, N=1024, D=64, causal=False
✅ Forward match OK
✅ Backward match OK
Performance:
(Mem): Custom: 116.8 MB, PyTorch: 784.5 MB (Δ: -667.8 MB, -85.1%)
(fwd): Custom: 0.57ms, PyTorch: 1.74ms (3.04x speedup)
(bwd): Custom: 1.94ms, PyTorch: 4.81ms (2.48x speedup)
(tot): Custom: 2.51ms, PyTorch: 6.54ms (2.61x speedup)
Validation:
(Fwd): dO err=2.44e-04 ≤ 2×4.88e-04
(Bwd): dQ err=4.88e-04 ≤ 3×1.46e-03
dK err=4.88e-04 ≤ 3×1.46e-03
dV err=4.88e-04 ≤ 3×9.77e-04
======================================================================
Test: B=1, H=3
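
The memory numbers in the M=N=1024 runs seem to scale with the number of heads and barely move with D, which is what one would expect if the PyTorch reference path materializes full attention score matrices and the custom kernel does not. The sketch below is only a rough sanity check, not something from the post: it assumes fp16 tensors (2 bytes per element) and treats the reported MB as MiB, and the helper name `score_matrix_mib` is made up for illustration.

```python
# Rough estimate only. Assumptions (not stated in the post): fp16 storage,
# i.e. 2 bytes per element, and that reported "MB" are close to MiB.
MIB = 1024 * 1024


def score_matrix_mib(batch, heads, m, n, bytes_per_elem=2):
    """Size in MiB of one dense B x H x M x N attention score matrix."""
    return batch * heads * m * n * bytes_per_elem / MIB


# (heads, head_dim, reported memory saving in MB), copied from the
# non-causal M=N=1024 runs in the post.
reported = [
    (16, 16, 323.4),
    (32, 16, 646.8),
    (16, 32, 326.9),
    (32, 32, 653.8),
    (16, 64, 333.9),
    (32, 64, 667.8),
]

for heads, head_dim, saved_mb in reported:
    one = score_matrix_mib(1, heads, 1024, 1024)
    ratio = saved_mb / one
    print(
        f"H={heads:2d} D={head_dim:2d}: one score matrix = {one:.0f} MiB, "
        f"reported saving = {saved_mb} MB (about {ratio:.1f} matrices)"
    )
```

Under those assumptions every row comes out at roughly ten score-matrix-sized buffers saved across the forward and backward pass, regardless of D. That is consistent with the saving coming from M x N intermediates, though the post itself does not say where the memory goes.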
