集合通信操作和其实现

10 min

集合通信(Collective Communication) 描述的是多个设备之间如何协同交换和聚合数据,其核心操作包括:Broadcast(一个设备向所有设备扩散数据)、Gather/Scatter(数据的收集与分发)、Reduce(对多个设备上的数据执行求和等归约计算)以及它们的 All 版本(所有设备最终都获得结果)。这些操作本质上是在回答两个问题:数据该流向哪里(一个 / 所有),以及 数据在流动过程中是否需要被计算(Reduce)。在实际系统中,这些逻辑语义通常通过 Ring 通信实现:所有设备首尾相连形成环,每一步只与左右邻居通信,从而形成高带宽、可流水化的数据传输。基于这种思想,AllReduce 可以进一步分解为 ReduceScatter + AllGather:先让不同设备分别完成不同数据分片的归约计算,再将结果重新同步给所有设备。这种分解避免了单点聚合带来的瓶颈,也是现代分布式训练(如 DDP、ZeRO、Megatron 等)中的核心通信基础。对于 pp 个设备和大小为 NN 的 tensor,Ring AllReduce 的总通信量约为 2p1pN2N2\frac{p-1}{p}N \approx 2N,这也是为什么 AllReduce 常被认为需要约两倍 tensor 大小的通信开销。

基础概念定义

我们定义每个设备为一个 rank.

在并行集合通信中,

  • All 表示通信的 dst 是所有设备
  • Reduce 表示对于数据执行 associative/commutative 计算(例如求和/求平均)
  • Gather 表示将“分散”(在各个设备)的数据 shard 合并起来
  • Scatter 则是 Gather 的反面,将完整的数据分块分发给多个设备

集合通信操作定义

假设我们需要对于 XX 数据进行操作,X[0:3]X[0:3] 表示数据的一部分。

Broadcast:一个 -> 所有(扩散)

情景:在所有 rank 中,只有一个 rank 持有完整数据 XX,而其他 rank 不持有该数据。

目标:我们希望让所有 rank 都持有 XX.

因此,我们需要将该 rank 上的数据复制并分发到所有其他设备,使数据从“单点”扩散到“全体”。这个过程称为 广播(Broadcast)

|  D0  |  D1  |  D2  |  D3  |
|      |      | data |      |
--copy-->
| data | data | data | data |

Gather:所有 -> 一个(收敛)

情景:在所有 rank 中,每个 rank 持有数据 XX 的一部分(例如 X[i]X[i] 表示第 ii 个分片)。

目标:将这些分散在不同设备上的数据片段汇总到一个指定的 rank 上,恢复出完整的数据 XX

因此,需要将各个 rank 上的数据发送到目标设备,并按照原有顺序进行拼接,使数据从“分散”重新汇聚为“整体”。这个过程称为 收集(Gather)

|    D0     |    D1     |    D2     |    D3     |
| data[0]   | data[1]   | data[2]   | data[3]   |
--concatenate-->
|           |           | data[0:3] |           |

AllGather:所有 -> 所有(复制扩散)

情景:在所有 rank 中,每个 rank 持有数据 XX 的一部分(例如 X[i]X[i] 表示第 ii 个分片)。

目标:让每个 rank 都拥有完整的数据 XX

因此,需要让各个 rank 之间交换各自持有的数据分片,并在本地按照顺序进行拼接,使每个设备最终都重建出完整的数据。这个过程称为 全收集(AllGather)

这被称之为 All Gather.

|    D0     |    D1     |    D2     |    D3     |
| data[0]   | data[1]   | data[2]   | data[3]   |
--concatenate-->
| data[0:3] | data[0:3] | data[0:3] | data[0:3] |

Reduce:所有 → 一个(计算 + 收敛)

情景:在所有 rank 中,每个 rank 都持有形状相同的数据 XX(例如不同设备上的梯度或中间结果)。

目标:对这些数据执行某种归约运算(例如求和、求平均等),并将结果保留在一个指定的 rank 上。

因此,需要将各个 rank 上的数据发送到目标设备,并在汇聚的过程中执行归约运算,使多个设备上的数据在逻辑上合并为一个结果。这个过程称为 归约(Reduce)

|  D0  |  D1  |  D2  |  D3  |
| data | data | data | data |
--reduce (e.g. sum/avg)-->
|      |      | data |      | (Reduced)

AllReduce:所有 → 所有(计算 + 扩散)

情景:在所有 rank 中,每个 rank 都持有形状相同的数据 XX(例如各自计算得到的梯度)。

目标:对这些数据执行归约运算(例如求和、求平均),并让所有 rank 都获得相同的归约结果。

因此,需要先对各个 rank 上的数据进行归约计算,将多个设备上的数据融合为一个结果;随后再将该结果分发给所有设备,使每个 rank 最终都持有一致的数据。这个过程称为 全归约(AllReduce)

|  D0  |  D1  |  D2  |  D3  |
| data | data | data | data |
--reduce (e.g. sum/avg)-->
| data | data | data | data | (Reduced)

Reduce Scatter:所有 → 所有(计算 + 分散)

情景:在所有 rank 中,每个 rank 都持有形状相同的完整数据 XX

目标:对这些数据执行归约运算(例如求和、求平均),但最终每个 rank 只保留结果中的一部分(例如第 ii 个 rank 保留第 ii 个分片)。

因此,需要在各个 rank 之间交换数据分片,并在传输过程中对对应位置的数据执行归约计算,使每个分片只在一个设备上完成归约并被保留下来。这个过程称为 归约分散(ReduceScatter)

这被称之为 Reduce Scatter.

|    D0     |    D1     |    D2     |    D3     |
| data[0:3] | data[0:3] | data[0:3] | data[0:3] |
--reduce[0]-|-reduce[1]-|-reduce[2]-|-reduce[3]-|
| data[0]   | data[1]   | data[2]   | data[3]   | (Reduced)

Ring 通信

在上一个章节介绍了通信的逻辑语义,,而在实际系统中,这些操作需要通过具体的通信算法来实现。其中,ring 通信是一种最常见且广泛使用的实现方式。其核心思想是:

  • 将所有 device 连成一个环。假设我们有 4 个设备,这意味着 0 -> 1 -> 2 -> 3 -> 0
  • 系统形成一个 pipeline,在每一步中:
  • 每个 GPU 发送给其 邻居,并且接受来自其 邻居的数据

例子:

|    D0     |    D1     |    D2     |    D3     |
| A[0]      | A[1]      | A[2]      | A[3]      |
    -           -           -           -
| ---A[0]-->| ---A[1]-->| ---A[2]-->| ---A[3]-->|
| A[0,3]    | A[1,0]    | A[2,1]    | A[3,2]    |
      -           -           -           -
| ---A[3]-->| ---A[0]-->| ---A[1]-->| ---A[2]-->|
| A[0,3,2]  | A[1,0,3]  | A[2,1,0]  | A[3,2,1]  |
        -           -           -           -
| ---A[1]-->| ---A[2]-->| ---A[0]-->| ---A[3]-->|
| A[0,3,2,1]| A[1,0,3,2]| A[2,1,0,3]| A[3,2,1,0]|

All-Reduce 计算优化

假设我们有 pp 个设备。回顾一下,All-Reduce 操作

  • 首先,所有设备上都有形状相同的完整的数据 xRNx\in \mathbb{R}^N
  • 我们希望对于所有 xi,i{1,,N}x_{i}, i\in \{1, \dots, N\},跨设备进行 Reduce 操作

我们可以

  • 利用 Reduce-Scatter 的思想,让每个 device 都负责一小部分的计算,即 device ii 负责 [xi/p×N,,xi/p×N+N][x_{i/p \times N}, \dots, x_{i/p \times N+N}] 的 Reduce 计算,
  • 再通过一次 All-Gather 操作将每个 device 得到的 reduced data shard 合并成完整的数据,完成同步

因此,一个重要的结论是:

All-Reduce=Reduce-Scatter+All-Gather\text{All-Reduce} = \text{Reduce-Scatter} + \text{All-Gather}

通信量计算

通信计算:

  • 假设不同 device 之间传递的 tensor size = NNNN 是整个 tensor 的大小且一共有 pp 个 device,每个设备持有 N/pN/p 的数据
  • 一共要进行 p1p-1 个 step
  • 因此每个设备在整个流程中要传输 (p1)N/p(p-1)N/p 的数据

在 All-Reduce 的情景下:

  • ReduceScatter 需要发送 (p1)×N/p(p-1) \times N/p 数据(这是因为进行 Reduce 操作时,需要接收所有其他设备上的数)
  • AllGather 拼接,同理也需要发送 (p1)×N/p(p-1) \times N/p 数据

因此,一共是

2×p1pN2 \times \frac{p-1}{p} N

数据,当 pp 数量很大的时候近似于 2N2N,即一次 All-Reduce 大约需要传输 2 倍 tensor size 的数据量。

总结

我们介绍了 Reduce, Broadcast, All {Gather/Reduce} 和 Reduce Scatter 通信原语,介绍其实现方式 Ring 通信以及如何优化 All-Reduce 操作。