作者|郭冉、李一鹏、柳俊丞、袁进辉
常用深度学习框架的自动并行机制还不够完善,还需要用户根据经验来配置并行方式,这给开发者带来了不小的智力负担。因此,实现自动最优并行就成为一个有趣的课题。
矩阵乘是深度学习最常用的底层计算原语,譬如卷积算子,注意力机制都是通过矩阵乘来实现的,所以大规模神经网络的并行实现大多数时候也是在处理分布式矩阵乘。本文就以如何最优地实现分布式矩阵乘为例来展示自动并行的解决思路。
1
如何实现最优的分布式矩阵乘?
通过上一篇文章《手把手推导 Ring all-reduce 的数学性质》我们知道了常见集群通信操作的通信量和所需通信时间的数学性质,在这篇文章里我们看看怎么使用这些性质来选择最优的并行矩阵乘策略。
在文章《如何超越数据并行和模型并行:从GShard 谈起》,我们介绍了如何从一般的数据并行、模型并行提炼出最一般性的算子并行的抽象表示 SBP。
假设我们希望在4张显卡(2台服务器,每台服务器上有2张显卡)上完成一个矩阵乘\( X\times W=Y \),也就是\(y_{ij}=\sum_{k}{x_{ik}\times w_{kj}}\),其中\( X \)和\( W \)按照特定的 SBP 签名被摆放(place)到4张显卡上,那么将有多个方式实现分布式矩阵乘,它们在数学上等价,不过需要调用的集群通信操作不同,从而触发的通信代价也不同。
沿用《手把手推导 Ring all-reduce 的数学性质》里的符号,\( p \)表示设备数,\( V \)表示矩阵大小 \( V_{x} \)表示矩阵\( X \)的大小,\( V_{w} \)表示矩阵\( W \)的大小), \( \beta \)表示传输带宽。
2
数据并行还是模型并行?
图 1:基于1D 矩阵乘的数据并行
如果\( X \)和\( W \)的SBP签名分别是\( S(0) \)和 \( B \) ,那么可以推导出来\( Y \)的 SBP是\( S(0) \),也就是左矩阵\( X \)是行划分,右矩阵\( W \)是在各个卡上是一模一样的拷贝(broadcast)。如果\( X \)表示特征数据 (feature map), \( W \)表示模型参数,那么这是一个典型的数据并行,下面我们分析一下数据并行的通信代价。
数据并行的反向需要执行集群通信操作 all-reduce,如果采用环状算法,那么所有设备间的数据传输量是\( 2(p-1)V_{w} \) ,执行时间是 \( \frac{2(p-1)V_{w}}{p\beta} \)。
图 2:基于输出层神经元划分的模型并行
如果\( X \)和\( W \)的SBP签名分别是\( B \)和\( S(1) \),那么可以推导出来\( Y \)的 SBP是\( S(1) \),也就是左矩阵\( X \)在各个卡上是一模一样的拷贝(broadcast),右矩阵\( W \)在各个卡上列划分。如果 \( X \)表示特征数据 (feature map),\( W \)表示模型参数,那么这是一个典型的模型并行,下面我们分析一下这种模型并行的通信代价。
如果\( Y \)以\( S(1) \)的状态参与下游的计算,那么\( Y=X \times W \)本身并不需要引入额外的通信。但假设\( Y \)需要被恢复成和 \( X \)一样的状态(broadcast)参与下游计算,则前向计算时需要在 \( S(1) \)签名的\( Y \) 上调用 all-gather操作,后向计算时需要在\( Y \)的反向error signal上调用reduce-scatter操作。那么前向和反向总的通信量是\( 2(p-1)V_{y} \),执行时间是\( \frac{2(p-1)V_{y}}{p\beta} \)。
注意,矩阵乘引入的通信量不只是由当前算子决定的,还取决于它所处的上下文;我们这里的分析假设下游的算子需要\( Y \)保持和输入\( X \)一样的SBP签名,在这种情况下讨论不同并行方式的通信量。
图 3:基于输入层神经元划分的模型并行
如果\( X \)和\( W \)的SBP签名分别是\( S(1) \)和 \( S(0) \) ,那么可以推导出来\( Y \)的 SBP是\( P \),也就是左矩阵\( X \)在各个卡上是列划分,右矩阵\( W \)在各个卡上行划分。如果\( X \)表示特征数据 (feature map),\( W \)表示模型参数,那么这也是一个模型并行的方式(只不过是对全连接层的输入神经元划分而来),下面我们分析一下这种模型并行的通信代价。
如果\( Y \)以与\( X \)相同的\( S(1) \)的状态参与下游的计算,则前向计算时需要在\( P \)签名的\( Y \)上调用 reduce-scatter 操作,后向计算时需要在\( Y \)的误差上调用all-gather操作。那么前向和反向总的通信量是\( 2(p-1)V_{y} \),执行时间是\( \frac{2(p-1)V_{y}}{p\beta} \)。
根据以上的分析,数据并行的通信量是\( 2(p-1)V_{w} \),模型并行的通信量是\( 2(p-1)V_{y} \),因此单就这一个矩阵乘而言,到底使用数据并行还是模型并行是比较容易确定的,也就是取决于\( V_{w} \)和\( V_{y} \)哪个大,如果\( V_{w} > V_{y} \),表示权重矩阵的容量大于输出特征数据的容量(譬如超大的全连接层),那么适合模型并行;如果\( V_{w} < V_{y} \),表示表示权重矩阵的容量小于输出特征数据的容量(譬如卷积层),那么适合数据并行。
值得一提的是,在实践中,数据并行和模型并行还不单单由\( V_{w} \)和\( V_{y} \)哪个大来决定,数据并行中all-reduce通信比较容易被反向计算所掩盖,而模型并行的通信不容易被计算掩盖,因此即使\( V_{w} > V_{y} \)了,理论上应该用模型并行了,但当数据并行反向掩盖all-reduce的优势超过模型并行中通信量更小的优势时,使用数据并行还是更优的。这就是问题的复杂之处,最优的并行方式不仅仅是一个代价函数决定的,还和系统具体实现密切相关。
3
高维并行(矩阵乘)是怎么回事?
在英伟达为大规模预训练模型开发的Megatron-LM里,矩阵乘使用了2D并行,譬如同一个算子在机器间使用了数据并行,机器内部使用了模型并行。有一篇论文也提出2D并行来实现矩阵乘An Efficient 2D Method for Training Super-Large Deep Learning Models(https://arxiv.org/pdf/2104.05...)。
2D并行是怎么回事?真的会带来好处吗?为什么呢?我们还没有发现已有文献对这个问题从理论上讨论清楚,希望这篇博客能彻底搞清楚这些问题。
图 4:2D 并行
假设我们有 2 台机器,每台机器 2 个设备,\( X \)在机器间是\( S(0) \),在机器内部是\( B \),而\( W \)在机器间是\( B \),在机器内部是\( S(1) \),计算结果在机器间是\( S(0) \),机器内部是\( S(1) \)。
这个例子里,机器间是数据并行,机器内部是模型并行。
把\( Y \)从 \( {S(0),S(1)} \)转换成和X 一样的\( {S(0),B} \) ,那么前向计算需要每台机器内部执行all-gather,反向需要在每台机器内部执行reduce-scatter,其传输量是 \( 2(\sqrt{p}-1)V_{y} \)。同时机器之间是数据并行,反向计算需要在第1台机器的第1张卡和第2台机器的第1张卡之间,以及第1台机器的第2张卡和第2台机的第2张卡之间分别调用all-reduce,传输量是 \( 2(\sqrt{p}-1)V_{w} \),总的传输量是 \( 2(\sqrt{p}-1)(V_{y}+V_{w} \))。
2D 的all-gather为例,我们再细致的解释一下上面的传输量是怎么推导出来的。
假设一共\( \sqrt{p} \)台机器,每台机器上有\( \sqrt{p} \)个设备,每台机器内部需要在\( \sqrt{p} \)个设备之间完成\( \frac{V_{y}}{\sqrt{p}} \)大小的矩阵,所以每台机器内部的传输量是\( \frac{2(\sqrt{p}-1)V_{y}}{\sqrt{p}} \),一共\( \sqrt{p} \)台机器,因此前向all-gather 传输量是\( 2(\sqrt{p}-1)V_{y} \) 。
图 5:2D 矩阵乘
2 台机器,每台机器 2 个设备,\( X \)在机器间是\( S(0) \) ,在机器内部是 \( S(1) \) ,而 \( W \)在机器间是\( B \) ,在机器内部是 \( S(0) \) ,计算结果在机器间是 \( S(0) \),机器内部是\( P \)。
机器间是数据并行,机器内部是模型并行。
把\( Y \)从\( {S(0),P} \)转换成和 \( X \) 一样的\( {S(0),S(1)} \) ,那么前向计算需要每台机器内部执行reduce-scatter,反向需要在每台机器内部执行all-gather,其传输量是 \( 2(\sqrt{p}-1)V_{y} \) 。同时机器之间是数据并行,反向计算需要在第1台机器的第1张卡和第2台机器的第1张卡之间,以及第1台机器的第2张卡和第2台机器的第2张卡之间分别调用all-reduce,传输量是 \( 2(\sqrt{p}-1)V_{w} \),
总的传输量是 \( 2(\sqrt{p}-1)(V_{y}+V_{w}) \)。
图 6:2D 矩阵乘
图 6 展示了经典的 2D SUMMA 算法的实现。直接按照图6所示的数据分布是无法直接执行矩阵乘的,\( X \)和 \( W \) 在机器内部都需要执行all-gather计算,变成图4所示的数据分布才可以,相应的反向计算需要在机器内部执行reduce-scatter,总的通信量是 \( 2(\sqrt{p}-1)(V_{x}+V_{w}) \) 。
4
高维矩阵乘有什么好处?
我们以图 4 所示的2D 矩阵乘为例来讨论高维矩阵乘相对于1D 矩阵乘带来了什么好处。
首先假设 \( V_{x}=V_{w}=V_{y}=V \),那么 1D 矩阵乘的通信量是 \( 2(p-1)V \),而2D 矩阵乘的通信量是\( 4(\sqrt{p}-1)V \),基本上可以认为,当 \( p>4 \),2D 矩阵乘通信量就小于 1D 矩阵乘的通信量了。
可以推测,如果是3D 矩阵乘,那么通信量是和\( \sqrt[3]{p} \)成正比的。高维矩阵乘的本质是减小了每一个集群通信操作的”宽度“,我们在上一篇博客《手把手推导 Ring all-reduce 的数学性质》推导过集群通信的通信量是和通信宽度成正比的。
5
高维矩阵乘会降低通信时间吗?
细心的朋友可能注意到了,我们在讨论1D矩阵乘的通信代价时,总是同时讨论通信量和通信时间,但是在讨论2D矩阵乘的通信代价时,却只讨论了通信量,没有讨论通信时间。刚才我们也讨论了,高维矩阵乘会降低通信量,那么高维矩阵乘的通信时间也会降低吗?
实际上不会。结论有点违反直觉,为什么呢?原因是:通信量变成原来的 \( \frac{1}{\sqrt{p}} \),但每个设备同时参与多组集群通信,每组集群通信可使用的带宽也变成原来的 \( \frac{1}{\sqrt{p}} \)。下面我们看一个具体的例子。
图 7:DGX-A100 通信拓扑
图7展示了DGX-A100机器的通信拓扑,假设一共有4台机器,每台机器有4个 GPU,每台机器有4张网卡,因此机器之间的带宽是每张网卡带宽的4倍。
图 8:1D 并行的环状通信拓扑
在1D 并行,假设所有 GPU 构成图8 所示的一个大环。机器间通信带宽为 \( \beta=\sqrt{p}\times \beta_{IB} \)(注意:下文的公式和上文公式带宽差一个\( \sqrt{p} \)系数,来源于此),其中 \( \beta_{IB} \) 表示IB网卡带宽,在DGX A100拓扑中机器间 IB 带宽通常小于机器内GPU设备间通信带宽,因此此处整体通信受限于机器间带宽,通信时间为 \( \frac{2(p-1)V}{p\times \sqrt{p}\beta_{IB}} \) (注意:分母需要乘以设备总数\( p \))。
图 9:2D 并行的环状通信拓扑
在2D 并行,以SUMMA 矩阵乘法为例,每行的 4 个GPU设备构成一个环,即[machine 0 : gpu 0, machine 1 : gpu0, machine 2 : gpu 0, machine 3 : gpu0]、[machine 0 : gpu 1, machine 1 : gpu1, machine 2 : gpu 1, machine 3 : gpu1]组成一个环等。每列的4个GPU设备也构成一个环。前向计算时,每个环上都要同时执行 all-gather 操作,跨机器的每个集群通信操作都会占用 \( \frac{1}{\sqrt{p}} \)的网络带宽,也就是\( \beta_{IB} \),机器内部的每个集群通信带宽不是瓶颈所在,因此不影响最终结果。通信时间不难推导,是\( \frac{2(\sqrt{p}-1)V}{p\times \beta_{IB}} \) (这里除以p得到的是每个设备的通信量),和 1D 并行的通信时间 \( \frac{2(p-1)V}{p\times \sqrt{p}\beta_{IB}} \) 是同一个数量级。
至此,我们知道:2D 矩阵乘减小了集群通信的宽度,因此降低了所需要的通信量,但不会降低通信时间。
甚至,在特定的情况下,1D 矩阵乘的通信时间要小于 2D 矩阵乘,这又是为什么呢?
2D 矩阵乘的通信时间是 \( \max{\frac{2(\sqrt{p}-1)V_{1}}{p\beta_{1}},\frac{2(\sqrt{p}-1)V_{2}}{p\beta_{2}}} \)
其中区别了不同的矩阵和不同环的传输带宽。假设 \( \beta_{1} < \beta_{2} \)(机器间带宽小于机器内部带宽),那么 2D 矩阵乘的通信时间至少是
\( \max{\frac{2(\sqrt{p}-1)V_{1}}{p\beta_{2}},\frac{2(\sqrt{p}-1)V_{2}}{p\beta_{2}}} \)
1D 矩阵乘的通信时间是选择数据并行和模型并行中更优的那一个:
\( \min{\frac{2(p-1)V_{1}}{p\sqrt{p}\beta_{1}},\frac{2(p-1)V_{2}}{p\sqrt{p}\beta_{1}}} \)
当 \( V_{1} \) 和\( V_{2} \) 相差比较悬殊时,不妨假设\( V_{1}<V_{2} \),那么 2D 并行通信时间的下界是 \( \frac{2(\sqrt{p}-1)V_{2}}{p\beta_{2}} \) ,而 1D 并行的通信时间是 \( \frac{2(p-1)V_{1}}{p\sqrt{p}\beta_{1}} \),不难得到,当 \( V_{1}<\frac{\beta_{1}}{\beta_{2}}V_{2} \) 时,1D 并行的通信时间一定小于 2D 并行的通信时间。
因此,2D 并行在降低通信量(或者带宽需求)上有优势,1D 并行在降低通信时间上有优势。
一般来说,一个神经网络中同时存在很多类似矩阵乘的算子,算子层次的并行都需要引入通信需求,通信带宽非常充裕,那么就可以放心的使用 1D 并行,这样确保通信时间是最小的;如果通信带宽是瓶颈,那么每一个算子都应该尽可能降低通信量的需求,节省带宽,这样才能让总体的通信时间最小。
2D 并行的带宽需求降低了,但通信时间没有变化,原因是什么呢? 直观的理解是,在2D 并行中一定有一部分带宽是被闲置了。想象一下,一个大环被切成几段,形成几个小环,小环和小环之间的带宽是不需要用的。
6
结语
如果你在GPU上实现过单卡矩阵乘法,那可能对上面2D矩阵乘的示意图很熟悉,没错,在单卡实现矩阵乘时,关键也在于尽可能减小global memory和shared memory之间的数据搬运。
因此,那里也需要做类似于分布式矩阵乘的通信代价分析,分布式是宏观层次的数据搬运,单卡是微观层次的数据搬运,二者在原理上非常相似。实际上,已有文献对分布式矩阵乘的通信代价的理论分析已经非常成熟,本文讨论的2D阵乘或3D矩阵乘的实现方式都已实现了各自拓扑下通信代价的理论下界。
本文只讨论了一个算子并行时的最优策略,其实每个算子的最优策略也和它所处的上下文相关,一个算子不仅仅要考虑那个并行策略对自身是不是有利,还要考虑它的计算结果对周围的算子是不是有利。
因此,给定一个神经网络,它的最优并行策略是一个组合优化问题,如果这个神经网络是链状(chain-structure)的,那么可以证明,使用动态规划算法就可以在多项式时间内求出全局最优解,当神经网络的结构不是链状时,就无法使用动态规划,就需要一系列手段尽可能降低搜索空间的规模。
auto-placement和auto-parallelism是业界广泛关注的一个热点问题。很多研究工作直接就把问题形式化成一个组合优化的问题,但比较少讨论分布式深度学习自身的数学规律。
OneFlow团队在研究过程中发现,如果能对问题本身的数学性质做深入的理论分析,充分利用这些理论性质,auto-placement和auto-parallelism的求解可以出乎意料的简单。
迄今为止,我们应该对数据并行和模型并行讨论得很深入了,未来,我们会对流水并行的理论性质展开讨论。
正如本文在讨论1D并行和2D并行实现时所画的各种示意图所示,不同的数据切分方式带来不同的并行方式,也带有不同的通信代价。有些切分方式并不直观,怎么才能从理论上保证一种切分方式是正确的?怎么才能穷尽所有理论上正确的切分方式?
OneFlow SBP提供了一种很强大的数学抽象,不仅可以用来分析1D矩阵乘,还可以很方便地分析2D矩阵乘,大大简化了分析这些复杂问题的难度。强烈推荐做这方面工作的小伙伴儿都来用这套工具。
如果想更具体了解SBP如何在分布式模型训练里发挥威力,可以参照 OneFlow 发布的LiBai(
https://github.com/Oneflow-In...) ,仅仅1万行核心代码就实现了NVIDIA Megatron-LM和Microsoft DeepSpeed需要五六倍代码量才能实现的功能。
欢迎下载体验 OneFlow v0.7.0 最新版本:
https://github.com/Oneflow-In...
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。