在训练如 Llama、GPT-4 这类千亿参数级别的大模型时,是否经常遇到这样的问题:前向传播刚刚开始,显存就已经耗尽?
CUDA out of memory
PyTorch 随即抛出内存溢出错误——这几乎是每一位大模型开发者都经历过的常态。
根本原因在于:
单张 GPU 的显存无法容纳整个模型的参数权重。以一个拥有 700 亿参数的模型为例,仅 FP16 精度下的权重就接近 140GB,而目前最高配置的 H100 显卡也只有 80GB 显存。面对这种困境,该如何解决?
虽然可以通过数据并行的方式提升吞吐量,但这种方式只是“表面缓解”——每张卡仍需保存完整的模型副本,显存压力并未减轻。真正有效的解决方案是将模型本身进行拆分,使其分布于多张 GPU 上协同运行。这就是模型并行的核心思想,而其中最为精细且高效的技术之一便是本文要深入探讨的主题:张量并行(Tensor Parallelism)。
可以这样类比:设想你要搬运一尊巨大的石雕,一个人难以承担全部重量。此时,并非寻找更强壮的人,而是将雕像切割成若干部分,每人负责一段,到达目的地后再重新拼合。张量并行正是遵循这一逻辑。它不满足于按层或按批次划分任务,而是深入到最基本的运算单元——矩阵乘法,对 $ Y = XW $ 这类核心操作进行分解。
其基本策略非常直观:
- 当输出维度过大时,将权重矩阵 $ W $ 按列切分;
- 当输入维度较高,或后续需要聚合梯度信息时,则采用行切分方式。
随后通过高效的集合通信机制(如 All-Gather 或 All-Reduce),确保各设备上的计算结果能够正确合并,恢复完整语义。
看似原理简单,但实际上涉及复杂的数学推导、实现细节和系统级优化。下面我们逐步解析两种典型场景。
列切分 + All-Gather:适用于输出扩展场景
考虑全连接层中的标准运算 $ Y = XW $,设输入 $ X \in \mathbb{R}^{b \times m} $,权重 $ W \in \mathbb{R}^{m \times n} $,输出 $ Y \in \mathbb{R}^{b \times n} $。当 $ n $ 很大时(例如 FFN 层中将隐藏维度放大四倍),我们可以考虑让每块 GPU 只计算输出的一部分。
具体做法是:使用 $ P $ 张 GPU,将权重 $ W $ 水平切分为 $ [W_1, W_2, ..., W_P] $,每个子块大小为 $ m \times (n/P) $。每个设备保留完整的输入 $ X $,并独立执行局部计算:
$$ Y_i = X \cdot W_i $$由于所有设备都持有完整的 $ X $,无需切分输入。计算完成后,各设备通过 All-Gather 操作交换各自的输出分片 $ Y_i $,最终拼接成完整输出:
$$ Y = [Y_1, Y_2, \dots, Y_P] $$这个过程类似于拼图游戏:每个人完成一部分画面,最终组合才能呈现整体图像。
import torch
import torch.distributed as dist
def tensor_parallel_linear_forward(x, weight_chunk, rank, world_size):
y_local = torch.matmul(x, weight_chunk) # 局部计算
# 所有设备交换结果
y_list = [torch.empty_like(y_local) for _ in range(world_size)]
dist.all_gather(y_list, y_local)
return torch.cat(y_list, dim=-1) # 拼接
该方法的优点在于计算完全独立,通信阶段统一处理;但缺点也明显:All-Gather 需要传输完整的输出张量,通信开销较大,尤其在 batch size 较大时,$ b \times n $ 的数据量会带来显著的带宽压力。
行切分 + All-Reduce:降低通信负载的替代方案
如果不需要保留各个分片,而只需要得到最终的求和结果,那么可以采用更节省通信资源的方式:行切分配合 All-Reduce。
这一策略常见于 FFN 结构的第二层,或注意力机制后的输出投影层。同样是 $ Y = XW $,此时我们将 $ W $ 垂直切分为:
$$ W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_P \end{bmatrix}, \quad W_i \in \mathbb{R}^{(m/P) \times n} $$同时输入 $ X $ 也被划分为 $ [X_1, X_2, \dots, X_P] $。然而,为了保证每台设备能完成局部计算,通常采取一种巧妙的方法:将输入 $ X $ 广播至所有设备,使得每个 GPU 都拥有完整的输入副本。
接着,各设备分别计算局部输出:
$$ Y_i = X \cdot W_i $$最后,通过 All-Reduce 对所有 $ Y_i $ 进行求和:
$$ Y = \sum_{i=1}^P Y_i $$从数学上看,结果等价于 $ Y = X \cdot (W_1 + \cdots + W_P) $,但由于每个设备只存储 $ 1/P $ 的权重,显存占用大幅下降,实现了高效的资源利用。
反向传播的过程也遵循类似的机制:梯度 $ \frac{\partial L}{\partial Y} $ 是完整的,会被广播到所有设备;每个设备基于本地数据计算对应的权重梯度 $ \frac{\partial L}{\partial W_i} = X^T \cdot \frac{\partial L}{\partial Y} $,随后通过 All-Reduce 操作将各设备的梯度结果进行聚合。 代码实现如下:def tensor_parallel_linear_backward(x, grad_output, weight_chunk, rank, world_size):
x_trans = x.transpose(0, 1)
grad_weight_local = torch.matmul(x_trans, grad_output)
# 跨设备求和
dist.all_reduce(grad_weight_local, op=dist.ReduceOp.SUM)
grad_input = torch.matmul(grad_output, weight_chunk.transpose(0, 1))
return grad_input, grad_weight_local
关键在于理解这样一个核心思想:
All-Reduce 不仅能够同步各设备上的梯度信息,还能隐式地完成“输出合并”的功能。这一点在涉及残差连接(Residual Connection)后接 LayerNorm 的结构中尤为重要——因为这类操作要求输入张量必须是完整且未分割的。
因此,可以总结出两种切分方式的应用场景与特点:
| 切分方式 | 典型用途 | 通信操作 | 输入是否复制 |
|---|---|---|---|
| Column-wise | QKV 计算、FFN Up-projection | All-Gather | 否 |
| Row-wise | Output Projection、FFN Down | All-Reduce | 是(广播) |
Input →
[Multi-Head Attention]
├── Q = X @ W_Q → Column-split
├── K = X @ W_K → Column-split
├── V = X @ W_V → Column-split
├── Attention & Softmax
└── Out = attn(Q,K,V) @ W_O → Row-split + All-Reduce
[Feed-Forward Network]
├── Up: X @ W_up → Column-split + All-Gather
├── GELU
└── Down: X @ W_down → Row-split + All-Reduce
这种模式被称为“**Split-Then-Combine**”——先对输入进行切分并在多个设备上并行处理,扩大计算规模;再通过通信操作将结果合并,恢复为完整表示。这种方法既保留了高维中间特征的表达能力,又有效控制了单卡显存的增长。
举个具体例子:假设模型维度 $ d = 4096 $,FFN 层将维度扩展 4 倍至 16384,并使用 4 卡进行张量并行:
- 权重矩阵 $ W_1 \in \mathbb{R}^{4096 \times 16384} $ 被按列切分为 4 块,每卡存储 $ 4096 \times 4096 $
- 各卡独立计算局部输出 $ H_i = X \cdot W_1^{(i)} $
- 然后执行 All-Gather,拼接成完整的激活值 $ H \in \mathbb{R}^{b \times 16384} $
- 接着在本地应用激活函数
- 进入下一级权重 $ W_2 \in \mathbb{R}^{16384 \times 4096} $,该矩阵按行切分,每卡持有 $ 4096 \times 4096 $ 的子块
- 每卡计算 $ O_i = A \cdot W_2^{(i)} $,最后通过 All-Reduce 得到最终输出
整个流程形成一个完美的闭环。
当然,真实场景往往更加复杂,部署过程中常会遇到以下挑战:
显存不足?
→ 引入梯度检查点(Gradient Checkpointing),用计算时间换取内存空间。
→ 或采用 Sequence Parallelism,将序列长度维度也进行切分。
通信成为瓶颈?
→ 使用 Ring-AllReduce 替代传统 All-Reduce,降低峰值带宽压力。
→ 开启计算与通信重叠(overlap communication with computation),实现边计算边传输。
负载不均衡?
→ 实施动态分片(Dynamic Sharding),根据序列长度自适应调整任务分配。
→ 采用异构调度策略,让性能更强的设备承担更多计算任务。
如何与其他并行策略协同工作?
这才是构建大规模训练系统的关键所在!尽管 Tensor Parallelism 具备细粒度拆分的能力,但它并非万能。现代分布式训练普遍采用 **3D 并行** 架构:
- **Tensor Parallelism**:按算子内部结构切分(细粒度)
- **Pipeline Parallelism**:按网络层数切分(粗粒度)
- **Data Parallelism**:按批次数据切分(横向扩展)
三者协同运作,才能真正支撑起千亿参数模型的高效训练。
[此处为图片4的位置已随内容调整]
最后说点通俗的话:你为什么应该重视 Tensor Parallelism?
因为它不仅仅是一个技术术语,更是一种解决问题的思维方式——
**当资源受限时,如何优雅地把一个大问题分解为多个小任务,并通过协作高效完成整体目标**。
这不仅是分布式深度学习的核心理念,也是现代 AI 工程体系的底层逻辑之一。
随着模型规模持续增长、硬件环境日益异构(CPU/GPU/TPU混合)、对推理延迟的要求不断提高,Tensor Parallelism 正在不断演进:支持稀疏激活、动态路由、自动化的切分策略选择等高级特性。
或许在不远的将来,我们调用一次跨数十甚至上百设备的并行矩阵运算,会像现在调用一个普通 matmul 函数一样自然和透明。而这一切的背后,正是这些看似“枯燥”的并行策略在默默支撑。
所以,请不要轻视那一次行切、那一列分。它们虽小,却是通往 AGI 之路的重要基石。

雷达卡


京公网安备 11010802022788号







