楼主: 一踹出
84 0

Tensor Parallelism拆分矩阵运算 [推广有奖]

  • 0关注
  • 0粉丝

等待验证会员

学前班

80%

还不是VIP/贵宾

-

威望
0
论坛币
0 个
通用积分
0.7987
学术水平
0 点
热心指数
0 点
信用等级
0 点
经验
30 点
帖子
2
精华
0
在线时间
0 小时
注册时间
2018-9-11
最后登录
2018-9-11

楼主
一踹出 发表于 2025-11-24 13:11:05 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

在训练如 Llama、GPT-4 这类千亿参数级别的大模型时,是否经常遇到这样的问题:前向传播刚刚开始,显存就已经耗尽?

CUDA out of memory

PyTorch 随即抛出内存溢出错误——这几乎是每一位大模型开发者都经历过的常态。

根本原因在于:

单张 GPU 的显存无法容纳整个模型的参数权重。以一个拥有 700 亿参数的模型为例,仅 FP16 精度下的权重就接近 140GB,而目前最高配置的 H100 显卡也只有 80GB 显存。面对这种困境,该如何解决?

虽然可以通过数据并行的方式提升吞吐量,但这种方式只是“表面缓解”——每张卡仍需保存完整的模型副本,显存压力并未减轻。真正有效的解决方案是将模型本身进行拆分,使其分布于多张 GPU 上协同运行。这就是模型并行的核心思想,而其中最为精细且高效的技术之一便是本文要深入探讨的主题:张量并行(Tensor Parallelism)

可以这样类比:设想你要搬运一尊巨大的石雕,一个人难以承担全部重量。此时,并非寻找更强壮的人,而是将雕像切割成若干部分,每人负责一段,到达目的地后再重新拼合。张量并行正是遵循这一逻辑。它不满足于按层或按批次划分任务,而是深入到最基本的运算单元——矩阵乘法,对 $ Y = XW $ 这类核心操作进行分解。

其基本策略非常直观:

  • 当输出维度过大时,将权重矩阵 $ W $ 按列切分
  • 当输入维度较高,或后续需要聚合梯度信息时,则采用行切分方式。

随后通过高效的集合通信机制(如 All-Gather 或 All-Reduce),确保各设备上的计算结果能够正确合并,恢复完整语义。

看似原理简单,但实际上涉及复杂的数学推导、实现细节和系统级优化。下面我们逐步解析两种典型场景。

列切分 + All-Gather:适用于输出扩展场景

考虑全连接层中的标准运算 $ Y = XW $,设输入 $ X \in \mathbb{R}^{b \times m} $,权重 $ W \in \mathbb{R}^{m \times n} $,输出 $ Y \in \mathbb{R}^{b \times n} $。当 $ n $ 很大时(例如 FFN 层中将隐藏维度放大四倍),我们可以考虑让每块 GPU 只计算输出的一部分。

具体做法是:使用 $ P $ 张 GPU,将权重 $ W $ 水平切分为 $ [W_1, W_2, ..., W_P] $,每个子块大小为 $ m \times (n/P) $。每个设备保留完整的输入 $ X $,并独立执行局部计算:

$$ Y_i = X \cdot W_i $$

由于所有设备都持有完整的 $ X $,无需切分输入。计算完成后,各设备通过 All-Gather 操作交换各自的输出分片 $ Y_i $,最终拼接成完整输出:

$$ Y = [Y_1, Y_2, \dots, Y_P] $$

这个过程类似于拼图游戏:每个人完成一部分画面,最终组合才能呈现整体图像。

import torch
import torch.distributed as dist

def tensor_parallel_linear_forward(x, weight_chunk, rank, world_size):
    y_local = torch.matmul(x, weight_chunk)  # 局部计算

    # 所有设备交换结果
    y_list = [torch.empty_like(y_local) for _ in range(world_size)]
    dist.all_gather(y_list, y_local)

    return torch.cat(y_list, dim=-1)  # 拼接

该方法的优点在于计算完全独立,通信阶段统一处理;但缺点也明显:All-Gather 需要传输完整的输出张量,通信开销较大,尤其在 batch size 较大时,$ b \times n $ 的数据量会带来显著的带宽压力。

行切分 + All-Reduce:降低通信负载的替代方案

如果不需要保留各个分片,而只需要得到最终的求和结果,那么可以采用更节省通信资源的方式:行切分配合 All-Reduce

这一策略常见于 FFN 结构的第二层,或注意力机制后的输出投影层。同样是 $ Y = XW $,此时我们将 $ W $ 垂直切分为:

$$ W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_P \end{bmatrix}, \quad W_i \in \mathbb{R}^{(m/P) \times n} $$

同时输入 $ X $ 也被划分为 $ [X_1, X_2, \dots, X_P] $。然而,为了保证每台设备能完成局部计算,通常采取一种巧妙的方法:将输入 $ X $ 广播至所有设备,使得每个 GPU 都拥有完整的输入副本。

接着,各设备分别计算局部输出:

$$ Y_i = X \cdot W_i $$

最后,通过 All-Reduce 对所有 $ Y_i $ 进行求和:

$$ Y = \sum_{i=1}^P Y_i $$

从数学上看,结果等价于 $ Y = X \cdot (W_1 + \cdots + W_P) $,但由于每个设备只存储 $ 1/P $ 的权重,显存占用大幅下降,实现了高效的资源利用。

反向传播的过程也遵循类似的机制:梯度 $ \frac{\partial L}{\partial Y} $ 是完整的,会被广播到所有设备;每个设备基于本地数据计算对应的权重梯度 $ \frac{\partial L}{\partial W_i} = X^T \cdot \frac{\partial L}{\partial Y} $,随后通过 All-Reduce 操作将各设备的梯度结果进行聚合。 代码实现如下:
def tensor_parallel_linear_backward(x, grad_output, weight_chunk, rank, world_size):
    x_trans = x.transpose(0, 1)
    grad_weight_local = torch.matmul(x_trans, grad_output)

    # 跨设备求和
    dist.all_reduce(grad_weight_local, op=dist.ReduceOp.SUM)

    grad_input = torch.matmul(grad_output, weight_chunk.transpose(0, 1))
    return grad_input, grad_weight_local
关键在于理解这样一个核心思想: All-Reduce 不仅能够同步各设备上的梯度信息,还能隐式地完成“输出合并”的功能。这一点在涉及残差连接(Residual Connection)后接 LayerNorm 的结构中尤为重要——因为这类操作要求输入张量必须是完整且未分割的。 因此,可以总结出两种切分方式的应用场景与特点:
切分方式 典型用途 通信操作 输入是否复制
Column-wise QKV 计算、FFN Up-projection All-Gather
Row-wise Output Projection、FFN Down All-Reduce 是(广播)
小贴士:如果某一层之后紧跟着残差连接或 LayerNorm 操作,则必须保证输出为完整张量。此时应优先选择 Row-wise 切分配合 All-Reduce,避免引入额外的 All-Gather 通信开销。 在实际的 Transformer 架构中(如 Megatron-LM),张量并行(Tensor Parallelism)正是以这种方式集成进去的:
Input → 
       [Multi-Head Attention]
         ├── Q = X @ W_Q     → Column-split
         ├── K = X @ W_K     → Column-split
         ├── V = X @ W_V     → Column-split
         ├── Attention & Softmax
         └── Out = attn(Q,K,V) @ W_O → Row-split + All-Reduce

       [Feed-Forward Network]
         ├── Up: X @ W_up    → Column-split + All-Gather
         ├── GELU
         └── Down: X @ W_down → Row-split + All-Reduce
这种模式被称为“**Split-Then-Combine**”——先对输入进行切分并在多个设备上并行处理,扩大计算规模;再通过通信操作将结果合并,恢复为完整表示。这种方法既保留了高维中间特征的表达能力,又有效控制了单卡显存的增长。 举个具体例子:假设模型维度 $ d = 4096 $,FFN 层将维度扩展 4 倍至 16384,并使用 4 卡进行张量并行: - 权重矩阵 $ W_1 \in \mathbb{R}^{4096 \times 16384} $ 被按列切分为 4 块,每卡存储 $ 4096 \times 4096 $ - 各卡独立计算局部输出 $ H_i = X \cdot W_1^{(i)} $ - 然后执行 All-Gather,拼接成完整的激活值 $ H \in \mathbb{R}^{b \times 16384} $ - 接着在本地应用激活函数 - 进入下一级权重 $ W_2 \in \mathbb{R}^{16384 \times 4096} $,该矩阵按行切分,每卡持有 $ 4096 \times 4096 $ 的子块 - 每卡计算 $ O_i = A \cdot W_2^{(i)} $,最后通过 All-Reduce 得到最终输出 整个流程形成一个完美的闭环。 当然,真实场景往往更加复杂,部署过程中常会遇到以下挑战: 显存不足? → 引入梯度检查点(Gradient Checkpointing),用计算时间换取内存空间。 → 或采用 Sequence Parallelism,将序列长度维度也进行切分。 通信成为瓶颈? → 使用 Ring-AllReduce 替代传统 All-Reduce,降低峰值带宽压力。 → 开启计算与通信重叠(overlap communication with computation),实现边计算边传输。 负载不均衡? → 实施动态分片(Dynamic Sharding),根据序列长度自适应调整任务分配。 → 采用异构调度策略,让性能更强的设备承担更多计算任务。 如何与其他并行策略协同工作? 这才是构建大规模训练系统的关键所在!尽管 Tensor Parallelism 具备细粒度拆分的能力,但它并非万能。现代分布式训练普遍采用 **3D 并行** 架构: - **Tensor Parallelism**:按算子内部结构切分(细粒度) - **Pipeline Parallelism**:按网络层数切分(粗粒度) - **Data Parallelism**:按批次数据切分(横向扩展) 三者协同运作,才能真正支撑起千亿参数模型的高效训练。 [此处为图片4的位置已随内容调整] 最后说点通俗的话:你为什么应该重视 Tensor Parallelism? 因为它不仅仅是一个技术术语,更是一种解决问题的思维方式—— **当资源受限时,如何优雅地把一个大问题分解为多个小任务,并通过协作高效完成整体目标**。 这不仅是分布式深度学习的核心理念,也是现代 AI 工程体系的底层逻辑之一。 随着模型规模持续增长、硬件环境日益异构(CPU/GPU/TPU混合)、对推理延迟的要求不断提高,Tensor Parallelism 正在不断演进:支持稀疏激活、动态路由、自动化的切分策略选择等高级特性。 或许在不远的将来,我们调用一次跨数十甚至上百设备的并行矩阵运算,会像现在调用一个普通 matmul 函数一样自然和透明。而这一切的背后,正是这些看似“枯燥”的并行策略在默默支撑。 所以,请不要轻视那一次行切、那一列分。它们虽小,却是通往 AGI 之路的重要基石。
二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:Parallel Paralle Tensor 矩阵运算 SOR

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
jg-xs1
拉您进交流群
GMT+8, 2025-12-5 13:19