Google Pathways记录
背景
Google 在 2022 年 3 月公开了 Pathways 架构设计。
概况
Google一直是深度学习变革的引领者,从 TensorFlow 到 JAX 再到 Pathways。
其中模型的训练过程中,分为数据并行、模型并行,以及数据模型并行三类。
- 数据并行 (Data Parallel)
- 模型并行 (Model Parallel)
- 数据模型并行(Data-Model Parallel)
[3]
数据并行
在现代深度学习中,因为数据集越来越大,以至于我们很难将其把它全部载入内存,我们通常使用所谓的随机梯度下降法[2],对数据集的每个批次(batch)进行梯度求导。举例而言,如果我们的数据集有10K个数据点,每次我们只从中取出16个数据点去计算根据这个批次得到的对梯度的估计,如果我们尝试一次性去计算所有数据点的梯度,我们的GPU显存极可能无法容纳那么多的数据。[2]
然而,随机梯度下降法的缺点在于其对梯度的估计, 相比于用整个数据集进行梯度计算得到的真实梯度来说,可能不够精确。因此,随机梯度下降法通常需要更多的训练时间去达到模型收敛。
有个自然的做法就是在一个更大的批次尺寸上进行更为精确的梯度估计,甚至我们可以使用整个的数据集大小的批次。为了实现这个目的,我们把这个大批次分割为很多小批次,在每个GPU上计算一个小批次,若干个GPU的梯度估计结果进行汇总后,进行加权平均,最终求和就得到了最终的大批次的梯度估计结果。[2]
模型并行
模型并行性乍一听挺唬人的,但是其实和令人生畏的数学没太大关系。模型并行更多的是一种对计算机资源的分配问题。有时候我们的模型可能太大了,甚至大到不能把整个模型载入一个GPU中,因为其中有着太多的层,太多的参数。因此,我们可以考虑把整个模型按层分解成若干份,把每一份(其中的层是连续的)载入不同的节点中,也即是每个不同的节点计算着整个模型的不同的层,计算着不同的层的梯度。通过这种方法,单个节点的参数量就减少了,并且使得用更为精确的梯度进行计算提供了可能性。[2]
fw注:这里的模型并行的解释可能不太准。但是不影响本文的逻辑,所以暂不去更正了。
JAX
谷歌于 2018 年底推出了 JAX。2020 年,DeepMind 宣布使用 JAX 来加速其研究。越来越多来自谷歌大脑(Google Brain)和其他机构的项目也都在使用 JAX。[4]
JAX的优势:速度。JAX 的速度比 NumPy 快了 N 个数量级。需要注意,JAX 使用的是 TPU,NumPy 使用了 CPU,以此强调 JAX 的速度上限远高于 NumPy。
Pathways
在JAX中,Google倡导SPMD (single program multiple data) ,也就是multi-client,没有所谓的master节点,各个worker 的script是对称的,各个worker 各干各的,但是有协同。multi-client在数据并行和模型并行下非常自然,各个worker就是完全对称的,在有流水并行的情况下,各个worker 执行不同的stage,不对称,SPMD并不是很协调,但multi-client还有另外一些比较微妙的优势。[1]
Single Client,如Hadoop 和 Spark,在single-client下写分布式程序脑力负担还是低一些,特别是解决了auto placement和auto parallelism之后,分布式代码就应该和单卡代码是一样的,只有single client 才会给人那种像写单机代码一样的感觉。
Google 开始倡导SPMD (single program multiple data) ,也就是multi-client,没有所谓的master节点,各个worker 的script是对称的,各个worker 各干各的,但是有协同。multi-client在数据并行和模型并行下非常自然,各个worker就是完全对称的,在有流水并行的情况下,各个worker 执行不同的stage,不对称,SPMD并不是很协调,但multi-client还有另外一些比较微妙的优势。