集合通信操作和其实现
集合通信(Collective Communication) 描述的是多个设备之间如何协同交换和聚合数据,其核心操作包括:Broadcast(一个设备向所有设备扩散数据)、Gather/Scatter(数据的收集与分发)、Reduce(对多个设备上的数据执行求和等归约计算)以及它们的 All 版本(所有设备最终都获得结果)。这些操作本质上是在回答两个问题:数据该流向哪里(一个 / 所有),以及 数据在流动过程中是否需要被计算(Reduce)。在实际系统中,这些逻辑语义通常通过 Ring 通信实现:所有设备首尾相连形成环,每一步只与左右邻居通信,从而形成高带宽、可流水化的数据传输。基于这种思想,AllReduce 可以进一步分解为 ReduceScatter + AllGather:先让不同设备分别完成不同数据分片的归约计算,再将结果重新同步给所有设备。这种分解避免了单点聚合带来的瓶颈,也是现代分布式训练(如 DDP、ZeRO、Megatron 等)中的核心通信基础。对于 个设备和大小为 的 tensor,Ring AllReduce 的总通信量约为 ,这也是为什么 AllReduce 常被认为需要约两倍 tensor 大小的通信开销。

基础概念定义
我们定义每个设备为一个 rank.
在并行集合通信中,
- All 表示通信的 dst 是所有设备
- Reduce 表示对于数据执行 associative/commutative 计算(例如求和/求平均)
- Gather 表示将“分散”(在各个设备)的数据 shard 合并起来
- Scatter 则是 Gather 的反面,将完整的数据分块分发给多个设备
集合通信操作定义
假设我们需要对于 数据进行操作, 表示数据的一部分。
Broadcast:一个 -> 所有(扩散)
情景:在所有 rank 中,只有一个 rank 持有完整数据 ,而其他 rank 不持有该数据。
目标:我们希望让所有 rank 都持有 .
因此,我们需要将该 rank 上的数据复制并分发到所有其他设备,使数据从“单点”扩散到“全体”。这个过程称为 广播(Broadcast)。
| D0 | D1 | D2 | D3 |
| | | data | |
--copy-->
| data | data | data | data |Gather:所有 -> 一个(收敛)
情景:在所有 rank 中,每个 rank 持有数据 的一部分(例如 表示第 个分片)。
目标:将这些分散在不同设备上的数据片段汇总到一个指定的 rank 上,恢复出完整的数据 。
因此,需要将各个 rank 上的数据发送到目标设备,并按照原有顺序进行拼接,使数据从“分散”重新汇聚为“整体”。这个过程称为 收集(Gather)。
| D0 | D1 | D2 | D3 |
| data[0] | data[1] | data[2] | data[3] |
--concatenate-->
| | | data[0:3] | |AllGather:所有 -> 所有(复制扩散)
情景:在所有 rank 中,每个 rank 持有数据 的一部分(例如 表示第 个分片)。
目标:让每个 rank 都拥有完整的数据 。
因此,需要让各个 rank 之间交换各自持有的数据分片,并在本地按照顺序进行拼接,使每个设备最终都重建出完整的数据。这个过程称为 全收集(AllGather)。
这被称之为 All Gather.
| D0 | D1 | D2 | D3 |
| data[0] | data[1] | data[2] | data[3] |
--concatenate-->
| data[0:3] | data[0:3] | data[0:3] | data[0:3] |Reduce:所有 → 一个(计算 + 收敛)
情景:在所有 rank 中,每个 rank 都持有形状相同的数据 (例如不同设备上的梯度或中间结果)。
目标:对这些数据执行某种归约运算(例如求和、求平均等),并将结果保留在一个指定的 rank 上。
因此,需要将各个 rank 上的数据发送到目标设备,并在汇聚的过程中执行归约运算,使多个设备上的数据在逻辑上合并为一个结果。这个过程称为 归约(Reduce)。
| D0 | D1 | D2 | D3 |
| data | data | data | data |
--reduce (e.g. sum/avg)-->
| | | data | | (Reduced)AllReduce:所有 → 所有(计算 + 扩散)
情景:在所有 rank 中,每个 rank 都持有形状相同的数据 (例如各自计算得到的梯度)。
目标:对这些数据执行归约运算(例如求和、求平均),并让所有 rank 都获得相同的归约结果。
因此,需要先对各个 rank 上的数据进行归约计算,将多个设备上的数据融合为一个结果;随后再将该结果分发给所有设备,使每个 rank 最终都持有一致的数据。这个过程称为 全归约(AllReduce)。
| D0 | D1 | D2 | D3 |
| data | data | data | data |
--reduce (e.g. sum/avg)-->
| data | data | data | data | (Reduced)Reduce Scatter:所有 → 所有(计算 + 分散)
情景:在所有 rank 中,每个 rank 都持有形状相同的完整数据 。
目标:对这些数据执行归约运算(例如求和、求平均),但最终每个 rank 只保留结果中的一部分(例如第 个 rank 保留第 个分片)。
因此,需要在各个 rank 之间交换数据分片,并在传输过程中对对应位置的数据执行归约计算,使每个分片只在一个设备上完成归约并被保留下来。这个过程称为 归约分散(ReduceScatter)。
这被称之为 Reduce Scatter.
| D0 | D1 | D2 | D3 |
| data[0:3] | data[0:3] | data[0:3] | data[0:3] |
--reduce[0]-|-reduce[1]-|-reduce[2]-|-reduce[3]-|
| data[0] | data[1] | data[2] | data[3] | (Reduced)Ring 通信
在上一个章节介绍了通信的逻辑语义,,而在实际系统中,这些操作需要通过具体的通信算法来实现。其中,ring 通信是一种最常见且广泛使用的实现方式。其核心思想是:
- 将所有 device 连成一个环。假设我们有 4 个设备,这意味着
0 -> 1 -> 2 -> 3 -> 0 - 系统形成一个 pipeline,在每一步中:
- 每个 GPU 发送给其 右 邻居,并且接受来自其 左 邻居的数据
例子:
| D0 | D1 | D2 | D3 |
| A[0] | A[1] | A[2] | A[3] |
- - - -
| ---A[0]-->| ---A[1]-->| ---A[2]-->| ---A[3]-->|
| A[0,3] | A[1,0] | A[2,1] | A[3,2] |
- - - -
| ---A[3]-->| ---A[0]-->| ---A[1]-->| ---A[2]-->|
| A[0,3,2] | A[1,0,3] | A[2,1,0] | A[3,2,1] |
- - - -
| ---A[1]-->| ---A[2]-->| ---A[0]-->| ---A[3]-->|
| A[0,3,2,1]| A[1,0,3,2]| A[2,1,0,3]| A[3,2,1,0]|All-Reduce 计算优化
假设我们有 个设备。回顾一下,All-Reduce 操作
- 首先,所有设备上都有形状相同的完整的数据
- 我们希望对于所有 ,跨设备进行 Reduce 操作
我们可以
- 利用 Reduce-Scatter 的思想,让每个 device 都负责一小部分的计算,即 device 负责 的 Reduce 计算,
- 再通过一次 All-Gather 操作将每个 device 得到的 reduced data shard 合并成完整的数据,完成同步
因此,一个重要的结论是:

通信量计算
通信计算:
- 假设不同 device 之间传递的 tensor size = , 是整个 tensor 的大小且一共有 个 device,每个设备持有 的数据
- 一共要进行 个 step
- 因此每个设备在整个流程中要传输 的数据
在 All-Reduce 的情景下:
- ReduceScatter 需要发送 数据(这是因为进行 Reduce 操作时,需要接收所有其他设备上的数)
- AllGather 拼接,同理也需要发送 数据
因此,一共是
数据,当 数量很大的时候近似于 ,即一次 All-Reduce 大约需要传输 2 倍 tensor size 的数据量。
总结
我们介绍了 Reduce, Broadcast, All {Gather/Reduce} 和 Reduce Scatter 通信原语,介绍其实现方式 Ring 通信以及如何优化 All-Reduce 操作。