从第一性原理出发让深度学习飞速运行
摘要
一篇综合性的博客文章,解释如何通过理解三个关键组成部分来优化深度学习性能:计算、内存带宽和开销,利用第一性原理识别性能区间并专注于有效的优化。
暂无内容
查看缓存全文
缓存时间: 2026/05/23 12:30
# 让深度学习‘Brrrr’起来:从第一性原理出发
来源:https://horace.io/brrr_intro.html
那么,你想提升深度学习模型的性能。如何着手这项任务呢?通常,人们会搬出一套以前可能有效过的技巧,或者是在推特上看到的“秘籍”:“使用原地操作!”、“把梯度设为None!”、“安装PyTorch 1.10.0,但别装1.10.1!”
用户常常采用这种临时拼凑的方法是可以理解的,因为在现代系统(尤其是深度学习)上获取性能,感觉上既像科学,又像炼金术。话虽如此,从第一性原理进行推理仍能排除大量不相关的方法,从而使问题更易处理。
例如,用深度学习在数据集上获得良好性能也需要大量猜测。但如果你的训练损失远低于测试损失,你就处于“过拟合”状态,此时尝试增加模型容量只会浪费精力。或者,如果训练损失与验证损失完全相同,那么尝试正则化模型也是浪费时间。
同样,你可以把深度学习场景的效率理解为由三个不同部分组成:
1. 计算:GPU在处理实际浮点运算(FLOPS)上所花的时间
2. 内存:在GPU内传输张量所花的时间
3. 开销:其余一切
与训练ML模型类似,了解自己处于哪种状态,就能把精力集中在真正重要的优化上。例如,如果你把所有时间都花在内存传输上(即处于**内存带宽受限**状态),那么提高GPU的FLOPS也无济于事。另一方面,如果你把所有时间都花在处理大型矩阵乘法上(即**计算受限**状态),那么用C++重写模型逻辑以减少开销也毫无帮助。
所以,如果你想让你的GPU持续“Brrrr”般高速运转,我们来讨论一下你的系统可能花费时间的三个组成部分——计算、内存带宽和开销。
*苦涩教训背后,是一群让GPU高效运行的工程师。图片来自 Gwern (https://www.gwern.net/Scaling-hypothesis#if_slide_2)*
注意:本文大部分内容将以GPU和PyTorch为例(因为我在PyTorch团队工作),但其中的原理几乎都适用于所有硬件和框架。
优化深度学习系统的一个视角是,我们希望最大化处于计算受限状态的时间。你为那312万亿次浮点运算付了钱,理想情况下,你理应获得这312万亿次浮点运算的性能。但要想让昂贵的矩阵乘法物有所值,就需要减少花在其他部分的时间。
但为什么重点关注最大化计算利用率,而不是内存带宽呢?原因很简单——你可以减少开销或内存成本,但除非改变实际执行的操作,否则(基本上)无法减少所需的计算量。
使最大化计算利用率更加困难的是计算增长速度相比内存带宽的差异。看看这张关于CPU FLOPS翻倍时间与内存带宽翻倍时间的表格
一种思考计算的方式是将其视为工厂。我们向工厂发送指令(开销),为它运送原材料(内存带宽),所有这一切都是为了保持工厂高效运转(计算)。
因此,如果工厂的效率提升速度超过了供应原材料的速度,工厂就越难达到峰值效率。
*尽管工厂规模(FLOPS)翻倍了——但如果带宽跟不上,性能也不会翻倍*
这种利用计算的难度日益增加,不仅意味着ML系统工程师有了永久的工作保障,也使得理解瓶颈变得更加重要。
关于FLOPS还有一个补充说明。现代机器学习加速器都拥有专门用于矩阵乘法的硬件,例如NVIDIA的“Tensor Core”。
因此,如果你没有进行矩阵乘法,就只能达到19.5万亿次浮点运算,而不是宣称的312万亿次。请注意,这并非GPU独有——事实上,TPU比GPU更加不通用。
GPU在除了矩阵乘法之外的所有操作上慢得多的现象,初看起来可能有问题——那么我们其他的算子(如层归一化或激活函数)怎么办呢?事实是,这些算子就FLOPS而言只是舍入误差。例如,看看这张来自这篇论文的BERT中不同类型算子的FLOP计数表,其中“张量收缩”就是矩阵乘法。
你可以看到,总共非矩阵乘法的算子只占FLOPS的0.2%,所以GPU计算这些算子慢15倍并没有问题。
但在这个例子中,归一化和逐点算子实际只达到了矩阵乘法**250倍和700倍更少的FLOPS**。
那么为什么我们的非矩阵乘法算子花费的时间远超预期呢?
回到我们的类比,罪魁祸首往往是原材料运往工厂和从工厂运出的时间。换句话说,就是内存带宽。
## 带宽
带宽成本本质上就是将数据从一个地方移动到另一个地方的成本。这可能包括将数据从CPU移动到GPU、从一个节点移动到另一个节点,甚至从CUDA全局内存移动到CUDA共享内存。我们这里重点讨论的是最后一种情况,通常称为“带宽成本”或“内存带宽成本”。
另外两种(通常分别称为“数据传输成本”和“网络成本”)当然也很重要,但深入讨论分布式性能的话,这篇文章就永远写不完了。
要理解内存带宽成本是什么,让我们回到工厂类比。
虽然工厂是我们实际工作的地方,但它不适合作为散装存储单元。很大程度上是因为既然我们在那里做实际工作,所有的存储都优化为便于快速**使用**(SRAM),而不是拥有很大的容量。
那么,实际结果和原材料存储在哪里呢?通常的做法是有一个仓库,可能位于土地便宜且空间充足的地方(DRAM)。然后,我们将物资运往工厂和从工厂运出(内存带宽)。
这种将数据进出计算单元的成本就是所谓的“内存带宽”成本。顺便说一句,你的GPU的DRAM就是`nvidia-smi`中显示的内容,也是导致你喜闻乐见的“CUDA内存不足”错误的主要因素。
需要注意的一点是,每当我们执行一个GPU内核,我们都必须将数据从GPU的DRAM移出并移回。
现在,想象一下执行像`torch.cos`这样的单元操作时会发生什么。我们需要将数据从存储运送到仓库,然后对每块数据执行一小点计算,然后再把数据运送回去。运输成本非常高。结果,我们几乎所有的时间都花在运输数据上,而不是实际计算本身。
由于我们所有的时间都花在内存带宽上,这种操作称为**内存受限操作**,这意味着我们并没有花太多时间在计算上。
好吧,这并不理想。我们能做些什么呢?让我们看看一系列算子可能的样子。
*这是一系列逐点算子可能的样子。*
嘿!这安排太蠢了。为什么要反复将相同的数据发送到全局内存再送回计算单元呢?我们应该把数据留在工厂,执行完所有计算,然后再送回去!
*与其把三角形(比喻)送回全局内存再读回来,不如一次性完成所有操作。*
这就是**算子融合**——深度学习编译器中最关键的优化。简单来说,与其将数据写入全局内存只是为了再次读取,我们通过同时执行几个计算来省略多余的内存访问。
例如,如果我们执行`x.cos().cos()`,通常需要进行4次全局读写。
```
x1 = x.cos()
x2 = x1.cos()
```
但通过算子融合,我们只需要2次全局内存读写!因此算子融合会将其加速2倍。
```
x2 = x.cos().cos()
```
好多了。
有几个注意事项让这变得有些棘手。首先,GPU在执行当前操作时需要知道接下来会发生什么。因此,在PyTorch的eager模式下(一次运行一个算子)无法进行这种优化。其次,我们实际上需要为这个融合生成CUDA代码,这又会打开一个全新的问题。
并非所有算子融合都像逐点算子那样简单。你可以将逐点算子融合到归约中,或者融合到矩阵乘法中。甚至矩阵乘法本身也可以看作是融合了广播乘法然后进行归约。
如果你有兴趣编写自定义CUDA内核,很可能在这里会获得最大的收益。任意两个PyTorch算子都提供了融合的机会,从而节省了它们之间读取/写入全局内存的内存带宽成本。此外,许多现有的编译器通常可以执行“简单”的融合——NVFuser和XLA就是两个例子。然而,自动化系统无法与人类的创造力相匹敌,所以如果你想尝试自己编写一些自定义CUDA内核,Triton是一个很好的起点。
最后,算子融合会带来一些令人惊讶的结果。例如,融合后的`x.cos().cos()`与单独调用`x.cos()`所花费的时间几乎完全相同。这就是为什么激活函数几乎成本相同,尽管`gelu`显然比`relu`包含更多操作。
这个事实导致了一些关于重新物化/激活检查点的有趣后果。本质上,进行额外的重新计算可能会导致**更少**的内存带宽,从而减少运行时间。因此,我们可以通过重新物化来同时降低内存和运行时间,我们利用这一点在AOTAutograd中构建了一个巧妙的min-cut优化环节。你可以在这里了解更多详情(也可能会在未来的博客文章中进一步探讨!)
#### 推理内存带宽成本
当需要判断你的操作是否受内存带宽限制时,一个计算器能派上大用场。
对于简单算子,直接推理内存带宽是可行的。例如,一个A100拥有每秒1.5TB的全局内存带宽,并且可以每秒执行19.5万亿次浮点运算。所以,如果你使用32位浮点数(即4字节),在GPU可以执行20万亿次运算的同一时间内,你可以加载4000亿个数字。此外,要执行一个简单的单元操作(比如将张量乘以2),我们实际上还需要将张量**写回**全局内存。
因此……在你执行大约一百次单元操作之前,花费在内存访问上的时间会多于实际计算的时间。
借助像NVFuser这样的融合编译器,我们实际上可以很容易地自己衡量这一点!你可以在Colab上查看代码。
如果你采用一个PyTorch函数,比如
```python
def f(x: Tensor[N]):
for _ in range(repeat):
x = x * 2
return x
```
并用融合编译器对其进行基准测试,我们就可以计算不同`repeat`值下达到的FLOPS和内存带宽。增加`repeat`是一种简单的增加计算量**而不增加**内存访问次数的方法——这也被称为提高**计算强度**。
具体来说,假设我们对这段代码进行基准测试,并找到每秒执行的迭代次数。然后,作为N(张量大小)的函数,我们将执行`2*N`次内存访问和`N * repeat`次FLOP。因此,达到的内存带宽为`bytes_per_elem * 2 * N * itrs_per_second`,达到的FLOPS为`N * repeat * itrs_per_second`。
现在,让我们将运行时间、FLOPS和达到的内存带宽绘制为计算强度的函数。注意所有轴都是对数尺度。
首先,注意运行时间在达到64次乘法之前**几乎没有明显增加**。这意味着在那之前,我们主要受内存带宽限制——我们的计算大部分处于空闲状态。
结果,我们开始时仅达到可怜的0.2万亿次浮点运算。随着计算强度每翻一倍,这个数字线性增长,直到接近我们9.75万亿次浮点运算的峰值[\[1\]](https://horace.io/brrr_intro.html#fn1)。一旦接近峰值FLOPS,我们就被认为是“计算受限”的。
最后,你可以看到达到的内存带宽从峰值附近开始,随着计算强度增加而下降。这完全符合预期,因为我们花在实际计算上的时间越来越多,而花在内存访问上的时间越来越少。
在这种情况下,很容易看出何时受计算限制以及何时受内存限制。当`repeat < 32`时,我们的内存带宽饱和而计算未充分利用。相反,当`repeat > 64`时,我们看到计算饱和(即接近峰值FLOPS),而使用的内存带宽开始下降。
对于更大的系统,往往更难判断是受计算限制还是内存带宽限制,因为通常它们同时包含计算受限和内存受限的组件。
衡量受计算限制程度的一种常用方法是计算已达到的FLOPS占峰值FLOPS的百分比。例如,如果你达到了峰值FLOPS的80%,那么你就知道至少80%的时间是受计算限制的,这已经相当不错了!其余时间很可能花在内存带宽操作上。[\[2\]](https://horace.io/brrr_intro.html#fn2)
然而,除了内存带宽成本之外,还有一件事可能导致你的GPU无法“Brrrr”起来。
## 开销
开销是指你的代码花费在**既不是传输张量也不是计算**上的时间。例如,花在Python解释器上的时间?开销。花在PyTorch框架上的时间?开销。花在启动CUDA内核(但不执行它们)上的时间?也是……开销。
开销之所以如此棘手,主要原因是现代GPU**真的很快**。一个A100每秒可以执行312万亿次浮点运算(312 TeraFLOPS)。相比之下,Python**真的慢得离谱**。在本地测试中,Python一秒钟只能执行3200万次加法。
这意味着,在Python执行**一次浮点运算**的时间内,A100可能已经消化掉了**975万次浮点运算**。
更糟的是,Python解释器甚至不是唯一的开销来源——像PyTorch这样的框架在到达实际内核之前还有多层分发。如果用PyTorch做同样的实验,我们每秒只能获得28万次操作。当然,PyTorch不是为了处理小张量而设计的,但……如果你在使用小张量(如科学计算中),你可能会发现PyTorch相比C++慢得惊人。
例如,看看这个PyTorch执行一次加法的火焰图配置。那个框吗?那就是实际执行计算的部分。其他一切都是纯粹的开销。
鉴于此,你可能
相似文章
@bqbrady: https://x.com/bqbrady/status/2064055370809778371
一篇关于现代深度学习的详细个人综述,聚焦于基础模型、视觉语言模型及其架构决策,面向那些希望获得直觉而非密集数学的读者。
深度表示学习的原理与实践:或记忆的数学理论
本书提出了深度表示学习的数学理论,旨在利用优化和信息论揭开大型深度网络内部机制的神秘面纱,使架构设计成为线性代数和微积分的问题。
本地 AI 硬件内存带宽(2026 年版)
本文深入解析内存带宽作为本地 AI 硬件性能的关键指标,对比了 NVIDIA、Apple、AMD、Intel 等厂商在不同性能层级下的当前 GPU 与统一内存系统。
@harshbhatt7585: https://x.com/harshbhatt7585/status/2063593933314113587
作者分享了从头训练一个160M参数大语言模型的经验,尝试了多种架构,如多Token预测和分层推理模型。他强调快速迭代、简化思路以及理解架构有效原因的重要性。
@ManningBooks: PyTorch 能带你走得很远,但当性能成为问题时,了解 GPU 层面的情况就至关重要…
为 Elliot Arledge 所著的《CUDA for Deep Learning》一书做的推广帖子,提供第一章总结视频,讲解 GPU 性能、CUDA 编程模型,以及何时需要编写自定义 CUDA 内核。