linear.py
LinearBase
class LinearBase(nn.Module):
def __init__(self,
input_size: int,
output_size: int,
bias : bool = False,
tp_dim : int | None = None
):
super.__init__()
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x : torch.Tensor) -> torch.Tensor:
raise NotImplementedError
这是所有并行Linear的抽象基类,不是“算线性层”,而是统一管理 Tensor Parallel Linear 的公共状态、参数结构和权重加载协议。
tp_dim是沿着哪个维度做切分,Column Parallel 是0,Row Parallel是1tp_rank是当前进程在Tensor Parallel通信组中的rank,tp_size是Tensor Parallel的并行度,即有多少个GPU参与并行计算self.weight = nn.Parameter(torch.empty(output_size, input_size)),这里创建了一个[output_size, input_size]的可训练参数矩阵- Q:为什么不是
[input_size, output_size]? - 因为Pytorch把Linear定义成,假设的形状是,那的形状是,那我们期望输入到
F.linear(x, w, bisa)中的形状就是
- Q:为什么不是
self.weight.weight_loader = self.weight_loader,给Parameter动态绑定了一个如何加载权重的方法-
if bias: self.bias = nn.Parameter(torch.empty(output_size)) self.bias.weight_loader = self.weight_loader else: self.register_parameter("bias", None)如果需要偏置项,就定义一下,也给他挂载同样的权重加载函数
否则就告诉
nn.Module这个模块没有bias参数,避免hasattr(self, 'bias')行为混乱
ReplicatedLinear
class ReplicatedLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bisa: bool = False
):
super().__init__(input_size, output_size, bisa)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
把权重在所有rank上完全复制
列并行与行并行
在分布式深度训练中,当模型大到单个 GPU 显存无法装下时(例如千亿参数的 LLM),我们就需要将模型拆分到多个 GPU 上。行并行(Row Parallelism)和列并行(Column Parallelism)是张量并行(Tensor Parallelism)中最核心的两种拆分方式。
这两者通常以 Megatron-LM 提出的方式协同工作,主要针对线性层 进行拆分。
列并行 (Column Parallelism)
核心思想: 将权重矩阵 按列拆分。
- 拆分方式: 将 拆分为 。
- 计算逻辑: 每个 GPU 持有相同的输入 ,分别计算出一部分输出特征。
- 通信需求: 前向传播: 每个 GPU 计算出 的一部分。如果下一层需要完整的 ,则需要进行一次 All-Gather 操作,或者保持拆分状态进入下一层(通常做法)。
- 反向传播: 需要进行一次 All-Reduce 来同步梯度。
优点: 适合作为多层感知机(MLP)的第一层,因为它不需要对输入 进行复杂的切分。
class ColumnParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bisa: bool = False,
):
tp_size = dist.get_world_size()
super().__init__(input_size, divide(output_size, tp_size), bisa, 0)
def weight_loader(self, para : nn.Parameter, loaded_weight: torch.Tensor):
para_data = para.data
shard_size = para_data.size(self.tp_dim)
stard_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, stard_idx, shard_size)
para_data.copy_(loaded_weight)
def forward(self, x):
return F.linear(x, self.weight, self.bias)
行并行 (Row Parallelism)
核心思想: 将权重矩阵 按行拆分。
- 拆分方式: 将 拆分为 。
- 计算逻辑: 为了能进行矩阵乘法,输入 也必须按列拆分为 。
$Y = [X_1, X_2] \begin{bmatrix} W_1 \\ W_2 \end{bmatrix} = X_1W_1 + X_2W_2$
- 通信需求:
- 前向传播: 每个 GPU 分别计算出部分和 ,最后必须进行一次 All-Reduce 操作将结果相加,得到最终的 。
- 反向传播: 梯度通过 All-Reduce 后的算子自然分发。
优点: 能够直接缩减中间特征图的维度,常作为 MLP 的第二层。
class RowParallelLinear(LinearBase):
def __init__(
self,
input_size : int,
output_size: int,
bias: bool = False,
):
tp_size = dist.get_world_size()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
shard_id = shard_size * self.tp_rank
loaded_weight = loaded_weight.narrow(self.tp_dim, shard_id, shard_size)
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
dist.all_reduce(y)
return y
列并行和行并行的关系
ColumnParallelLinear和RowParallelLinear在__init__()上的区别是tp_dim的不同,列并行的tp_dim是0,因为我们的LinearBase定义的时候是,我们要拆分的是output_size,所以dim是0,行并行也同理
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
# 获取param指向的Tensor的数据向量
param_data = param.data
# 获取被拆分后的那个维度,要时刻记住,我们的param是并行拆分后的,所以size是被除过的
shard_size = param_data.size(self.tp_dim)
# 获取当前这个gpu上对于param应该加载的下标id,前面有tp_rank个,每一个的大小是shard_size,所以乘起来就得到了id
shard_id = shard_size * self.tp_rank
# loadded_weight是完整的矩阵,所以我们需要聚焦到对应的维度,获得对应的大小
loaded_weight = loaded_weight.narrow(self.tp_dim, shard_id, shard_size)
# 拷贝过去
param_data.copy_(loaded_weight)
行并行和列并行还有一点不同的是,forward函数,如果有偏置项bias,
列并行的数学计算公式为,对于列并行,每个并行的linear都会算上bias,但是列并行的结果是二者拼接在一起,所以不会产生bias多加的问题
行并行的数学公式计算为:
对于行并行,每个并行的linear都会加上bias,统一加上的时候,就会多家tp_size-1份bias
解决方法是,只让其中一个linear计算bias,比如tp_rank=0的保留他原来的bias值,其他的全部置None,然后用一个all_reduce来同步计算出来的y
def forward(self, x):
return F.linear(x, self.weight, self.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
dist.all_reduce(y)
return y
MergedColumnParallelLinear
MergedColumnParallelLinear 这一层的存在,本质上是为了解决一个在大模型实现中非常现实、也非常工程化的问题:多个逻辑上独立的线性层,如果它们具有相同的输入维度并且在计算上总是“同时出现”,那么完全可以把它们在参数层面合并成一个大的 Column Parallel Linear,以减少 kernel 启动次数、提升访存与计算效率,同时又不破坏原有的模型语义。这一类结构在 Transformer 的 MLP、Attention 投影中都极其常见。
我们举个例子,先从完全不并行、最直观的 MLP开始。假设你在写一个带 gate 的 MLP(Qwen / LLaMA 那一类),最自然的写法一定是这样的:
gate = Linear(hidden_size, intermediate_size)(x)
up = Linear(hidden_size, intermediate_size)(x)
y = silu(gate) * up
这里有三个非常直观的事实:第一,gate 和 up 输入一模一样,都是 x;第二,它们的 输出维度一模一样,都是 intermediate_size;第三,它们永远成对出现,你不可能只算 gate 不算 up,也不可能反过来。
接下来我们加一点“现实约束”:模型变大了,intermediate_size 很大,于是你开始做张量并行(TP)。这时候,你把这两个 Linear 都换成 Column Parallel 的版本,大概就变成:
gate = ColumnParallelLinear(hidden_size, intermediate_size)(x)
up = ColumnParallelLinear(hidden_size, intermediate_size)(x)
y = silu(gate) * up
这两行 Linear,除了名字不同,几乎一模一样,而且都是对同一个 x 做线性变换。
那我们可不可以把两个矩阵合并成一个来节省计算时间,当然是可以的
原来:
W_gate: [intermediate, hidden]
W_up: [intermediate, hidden]
合并成:
W_merged: [2 * intermediate, hidden]
进行一次线性变换
gate_up = x @ W_merged.T
得到的结果:
gate_up = [ gate , up ]
这里我们也可以解释一下,为什么不用 RowParallelLinear,因为用ColumnParalle是按列展开,根据公式,我们可以发现,两个矩阵放在一起可以直接拆开,乘法的时候毫不影响互相之间的计算,很方便我们合并小矩阵
我们可以解读一下源码
class MergedColumnParallelLinear(ColumnParallelLinear):# 继承的是ColumnParallelLinear类
def __init__(
self,
input_size : int,
output_sizes : list[int], # outpu_sizes是一个列表,因为我们合并的是多个ColumnParallelLinear,大小可能不一样
bisa : bool = False
):
self.output_sizes = output_sizes
super.__init__(input_size, sum(output_sizes), bisa) # 我们合并后的ColumnParallelLinear的输出维度应该是所有小ColumnParallelLinear的outpu_size之和
def weight_loader(self, param : nn.Parameter,loadded_weight : torch.Tensor,loaded_shard_id : int):
param_data = param.data # 先取数据
# 我们聚焦于output,从宏观的角度来说,假设我们合并的矩阵的形状是[input_size * a, input_size * b, input_size * c, ...],即大矩阵的output的大小是[a, b, c, ...]若干个子矩阵拼再一起,那划分到不同的gpu上,每个gpu持有的矩阵的output的大小应该是[a/tp.size, b/tp.size, c/tp.size, ...]
# 而 loaded_shard_id 就是我们指的 a/b/c的下标,我们要获取它在output里对应的下标,那计算方式就是 sum(self.output_sizes[:loaded_shard_id]) // self.tp_size,先求一下大矩阵前loaded_shard_id个的和,然后除以 tp_size
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
# shard_size 是我们要加载 loaded_shard_id 对应的小矩阵的output长度,我们可以利用output_sizes获取对应的大矩阵的output长度,除以 tp_size 就可以获得分到每个gpu上的小矩阵的output长度
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
# 聚焦到 tp_dim ,获取从 shard_offse 开始,长度为 shard_size 的矩阵
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
# 虽然我们在推理定义模型的时候用的是合并后的MergedColumnParallelLinear,但是在训练的时候用的还是分开的ColumnParallelLinear,也就是说 loadded_weight 是一个单独的矩阵的参数,也就是形如 input * a 的矩阵,我们把他切成 tp_size 份,取对应的 tp_rank 的那份即是需要加载进去的权重
loadded_weight = loadded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
param_data.copy_(loadded_weight)
QKVParallelLinear
在 Transformer 的注意力机制中,输入向量 需要被投影到三个不同的向量空间:
- Q (Query/查询):代表“我正在找什么”。
- K (Key/键):代表“我能提供什么”。
- V (Value/值):代表“我包含的具体信息”。
数学上,这通常对应三个独立的线性层:
为了计算效率,工程上通常将这三个权重矩阵合并成一个巨大的矩阵,一次性完成乘法操作。
QKVParallelLinear的逻辑明明就是上面的MergedColumnParallel,为什么要单独写一个?
原因有三点:
- 处理 GQA/MQA 的非对称性
通用的
MergedColumnParallelLinear通常假设它合并的几个层是“对称”的,或者只是简单地平分。 但在现代模型中,GQA(Grouped Query Attention) 是标配。- Q 的维度:
num_heads * head_size - K/V 的维度:
num_kv_heads * head_size(通常num_kv_heads远小于num_heads)
我们在
QKVParallelLinear内部专门处理了这种非对称的计算逻辑:如果用通用的
Merged类,你必须手动计算这些复杂的 Offset 并传进去,而专用类可以根据head_size和头数自动推导,极大地降低了出错概率。 - Q 的维度:
- 语义清晰度与后处理逻辑
Q、K、V不是三个独立的线性层,它们是自注意力机制的有机整体。- Reshape 需求:在
forward之后,输出张量必须立即被view或reshape成[batch, seq, head, head_size]。 - 专用属性:
QKVParallelLinear会把num_heads和num_kv_heads存为类属性。后续的 Attention 算子(如 FlashAttention)可以直接读取这些属性,而不需要再次计算或查找配置。 - 代码可读性:看到
QKVParallelLinear任何人都会立刻明白这是 Attention 的入口,而MergedColumnParallelLinear可能被用于 MLP 的gate_up_proj(SwiGLU),两者的物理结构虽然相似,但数学意义完全不同。
- Reshape 需求:在
- 加载权重的复杂映射
特定的加载协议:官方权重(如 HuggingFace 格式)往往将 Q、K、V 存储为独立的 Tensor。
Offset 自动计算:
QKVParallelLinear知道K永远跟在Q后面,V永远跟在K后面。这种“固定的排列契约”如果写在通用的Merged类里会显得非常臃肿(你需要传一堆shard_id和复杂的映射表),而在专用的 QKV 类里则非常优雅。
具体源码如下,逻辑很简单,如上MergedCollumnParallel
class QKVParallelLinear(ColumnParallelLinear):
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_kv_heads: int | None = None, # 如果kv_heads是None,表示默认等于num_heads
bias: bool = False,
):
tp_size = dist.get_world_size()
total_kv_heads = total_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
self.num_kv_heads = divide(total_kv_heads, tp_size)
output_size = (self.num_heads + 2 * self.num_kv_heads) * self.head_size
super().__init__(hidden_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loadded_weight: torch.Tensor, loadded_shard_id: str):
param_data = param.data
assert loadded_shard_id in ["q", "k", "v"]
if loadded_shard_id == "q":
shard_size = self.head_size * self.num_heads
shard_offset = 0
elif loadded_shard_id == "k":
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.head_size * self.num_heads
else:
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loadded_weight = loadded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
param_data.copy_(loadded_weight)