loader.py
import os
from glob import glob
import torch
from torch import nn
from safetensors import safe_open
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
def load_model(model: nn.Module, path: str):
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
for file in glob(os.path.join(path, "*.safetensors")):
with safe_open(file, "pt", "cpu") as f:
for weight_name in f.keys():
for k in packed_modules_mapping:
if k in weight_name:
v, shard_id = packed_modules_mapping[k]
param_name = weight_name.replace(k, v)
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader")
weight_loader(param, f.get_tensor(weight_name), shard_id)
break
else:
param = model.get_parameter(weight_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, f.get_tensor(weight_name))
这是加载模型参数的函数,看似很简单,其实大有玄机
在理想世界里,我们可以直接model.load_state_dict(torch.load(...)),但我们这里做不了,原因有两大点
-
模型结构和 checkpoint 的“参数组织方式”不一致
-
比如我们Qwen3模型里Attention用的
qkv_proj的完整矩阵,MLP用的gate_up_proj的完整矩阵,简单介绍一下原因:- 可以减少内核启动开销,每当 CPU 想要 GPU 算一个矩阵乘法时,都要发一个指令,有固定的小延迟,减少启动次数,既可以减少时间开销
- 提高计算吞吐量,GPU适合处理大矩阵乘法,小矩阵乘法可能无法填满CUDA Cores,导致硬件半闲置
- 内存访问优化,对Attention的计算,输入向量
在计算 时是完全相同的,不合并的话GPU需要读写三次 ,合并的话仅需一次,可以减少显存宽带压力。
-
而 HF 风格 checkpoint 通常是拆开的,比如QKV拆成
q_proj.weight / k_proj.weight / v_proj.weight,MLP拆成gate_proj.weight / up_proj.weight -
这样的不对齐,导致无法直接加载参数
-
-
如果开启了TP,每张卡上的参数其实是“整块参数的一部分”,不同的卡需要加载的参数是不同的,我们不需要完全加载所有参数,所以需要特别处理,比如TP会有按行切和按列切,都需要不同的加载逻辑
为了更好的理解这个代码,我们需要掌握一些前置知识:
PyTorch 的参数系统
在 PyTorch 里,模型是一个层级系统:
nn.Module (模块 / 容器 / 计算逻辑)
│
├── nn.Module (子模块)
│ ├── nn.Parameter (真正的权重)
│ └── nn.Parameter
│
└── nn.Module
└── nn.Parameter
nn.Module 是 PyTorch 里所有模型、层、子层的统一抽象。任何你称之为“模型”“层”“子层”的东西,本质上都是一个 nn.Module 对象。只要你在 __init__ 里把一个 nn.Module 或 nn.Parameter 赋值给 self.xxx,它就会被自动注册进当前模块,形成一个树状结构,树的叶子节点是nn.Parameter ,其他节点都是nn.Moudle。
nn.Parameter在数据层面就是一个 Tensor,但在语义层面,它向 PyTorch 明确声明:“这是一个模型参数”。 PyTorch 并不会把所有 Tensor 都当作权重,只有 nn.Parameter 才会被自动加入参数列表、被优化器更新、被写入 checkpoint。这种区分让模型可以同时拥有“参与训练的状态”和“纯粹用于计算的中间张量”,而不会混在一起。
当 PyTorch 需要把模型参数导出成一个可序列化、可传输、可重建的形式时,它使用的不是对象引用,而是一个纯数据字典,也就是 state_dict。state_dict 是一个 {字符串名字 → Tensor} 的字典,表示模型当前的全部参数状态。这些字符串的名字不是随便取的,而是从树根出发,按树的路径(模型的属性名)一步步走,直到到达最后的nn.Parameter,最后以参数名结尾,比如.weight或者.bias,这中间用.来连接,比如model.layers.0.self_attn.qkv_proj.weight。重要的是:state_dict 里的 value 是 Tensor,而不是 nn.Parameter,这让它和具体的模型对象解耦——只要名字对得上,任何兼容结构的模型都可以加载这些数值。
model.get_parameter(),正因为是树状结构,有了路径,就可以找到模型底层的nn.Parameter对象,获得或者赋予相关的权重。
param.data.copy_()写入权重,.data 是参数底层 Tensor 的一个直接视图,它绕过了 autograd 的计算图机制;.copy_() 是一个原地拷贝操作,表示“用右边 tensor 的数值覆盖左边这块内存”。
checkpoint和safetensors
checkpoint是某一时刻模型的具体状态,保存在磁盘里,包含模型参数,也可能还有别的优化器状态、学习率调度器状态等等别的
最早 PyTorch 的 checkpoint 通常用 torch.save,底层是 Python 的 pickle。这种方式有两个严重问题:
pickle格式允许在加载时执行任意 Python 代码,这有安全风险。pickle是“整体反序列化”,你没法在不加载全部内容的情况下只读其中一个参数。如果只想加载一部分权重或者只在 CPU 上检查一下 key,pickle 会直接把你拖垮。
于是 safetensors 出现了,它的设计目标非常明确:只存张量,不存代码;只做零拷贝读,不做对象反序列化;允许你按 key 精确索引 tensor。
safetensors 的设计非常纯粹:一个简单的 JSON 请求头 + 一块连续的原始二进制数据。
| 字节偏移 | 长度 | 内容 | 说明 |
|---|---|---|---|
| 0 ~ 7 | 8 字节 | Header Size | 一个无符号 64 位整数(uint64),表示后面 Header 的字节长度。 |
| 8 ~ (8 + Header Size) | 变长 | Header (JSON) | 一个 UTF-8 编码的 JSON 字符串,包含所有权重的元数据。 |
| (8 + Header Size) ~ EOF | 变长 | Data Buffer | 原始二进制数据流。所有张量的具体数值都紧密排列在这里。 |
一个典型的 Header 看起来像这样:
{
"__metadata__": { "format": "pt" },
"model.layers.0.self_attn.q_proj.weight": {
"dtype": "F16",
"shape": [4096, 4096],
"data_offsets": [0, 33554432]
},
"model.layers.0.self_attn.k_proj.weight": {
"dtype": "F16",
"shape": [4096, 4096],
"data_offsets": [33554432, 67108864]
}
}
__metadata__: 可选部分,存储模型版本、框架来源等信息。dtype: 数据类型(如F16,BF16,F32,I32等)。shape: 张量的维度(如[4096, 4096])。data_offsets: 这是一个包含两个数字的列表[start, end]。它们代表该张量在 Data Buffer 部分的相对起始和结束字节偏移量。
通过使用with safe_open(file, "pt", "cpu") as f:的方式加载safetensors 文件,此时他就像一个只读的字典一样,可以通过f.keys()获得存储的所有参数名,如
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
也可以通过f.get_tensor(weight_name)来获取参数名weight_name对应的tensor值,也就是具体参数值
packed_modules_mapping
模型在设计的时候使用的具体架构和模型保存成safetensors文件时的存储结构是不一样的,存储的时候safetensors文件把所有参数分开存储了,而推理设计模型的时候可能会存在一些矩阵结构合并的操作,那我们在加载safetensors时就需要特殊处理这些地方,所以我们需要一个映射结构,告诉我们这二者之间的对应关系
在真正的大模型工程里,权重是被两种完全不同的“世界观”使用的。一边是训练与模型定义的世界,另一边是推理与系统优化的世界。这两个世界对“参数应该长什么样”有根本不同的诉求,而 packed_modules_mapping 就诞生在这条断层线上。
在训练和模型表达的世界里,参数的组织方式首先服务的是“语义清晰”。注意力机制里有 Query、Key、Value,这是论文里的概念,也是人类理解模型的方式;MLP 里有 gate 和 up,这是为了描述激活函数的结构。这种拆分方式让模型定义直观、让微调工具(比如 LoRA)可以只作用在某一个投影上,也让不同模型之间的权重更容易复用。因此 HuggingFace 生态中,checkpoint 几乎统一采用“拆开的权重命名”:q_proj.weight、k_proj.weight、v_proj.weight、gate_proj.weight、up_proj.weight。这是一种表达友好、生态友好的组织方式。
而在推理系统和高性能实现的世界里,关注点完全变了。这里最重要的是算子数量、内存访问模式和并行效率。三次线性变换不如一次线性变换快,三个 GEMM kernel 不如一个 GEMM kernel 高效;如果再考虑 Tensor Parallel,把权重按输出维度切分并融合成一个大矩阵,反而更容易实现负载均衡和高吞吐。因此你现在的模型实现里,注意力层不再有三个独立的 q_proj/k_proj/v_proj,而是只有一个 qkv_proj;MLP 不再有 gate_proj 和 up_proj,而是一个 gate_up_proj。这种结构不是为了“好看”,而是为了计算友好和系统友好。
同一份模型语义,被两种世界用两种完全不同的参数结构来表达。此时,加载权重时,必须存在一个“翻译层”,能把safetensors中拆分的参数映射到我们推理所设计的模型内部的融合参数上去。
例如我们的Qwen3:
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),
"k_proj": ("qkv_proj", "k"),
"v_proj": ("qkv_proj", "v"),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
checkpoint 中的 q_proj 不是一个独立存在的参数,而是模型中 qkv_proj 这个大参数里,专门用于 Query 的那一段。k_proj和v_proj也是类似。同样地,"gate_proj": ("gate_up_proj", 0) 和 "up_proj": ("gate_up_proj", 1) 并不是随意的编号,而是在告诉加载器:gate 和 up 各自对应融合矩阵中的哪一段。
loader.py
好了,现在了解了上述知识,我们来重新重新定义一下loader.py的作用
我们要把checkpoint,也就是.safetensors文件里面的所有参数都加载到我们推理的具体模型里面去
我们可以先看一下具体的.safetensors文件的内部有什么,以qwen3为例
🔹 Embedding (1 个)
model.embed_tokens.weight (151936, 1024)
🟢 RMSNorm (113 个)
model.layers.0.input_layernorm.weight (1024,)
model.layers.0.post_attention_layernorm.weight (1024,)
model.layers.0.self_attn.k_norm.weight (128,)
... 还有 110 个
🔵 Attention (QKV) (168 个)
model.layers.0.self_attn.k_norm.weight (128,)
model.layers.0.self_attn.k_proj.weight (1024, 1024)
model.layers.0.self_attn.o_proj.weight (1024, 2048)
... 还有 165 个
🟡 MLP (84 个)
model.layers.0.mlp.down_proj.weight (1024, 3072)
model.layers.0.mlp.gate_proj.weight (3072, 1024)
model.layers.0.mlp.up_proj.weight (3072, 1024)
... 还有 81 个
🔸 LM Head (1 个)
lm_head.weight (151936, 1024)
我们可以把模型中的层按照这是否是融合算子和是否TP两个维度重新组合:
| 分类 | 是否语义融合 | 是否 TP 切分 | 典型代表 |
|---|---|---|---|
| 全能型 (Fused + TP) | 是 | 是 | qkv_proj (ColumnParallel), gate_up_proj (ColumnParallel) |
| 骨干型 (Unfused + TP) | 否 | 是 | o_proj (RowParallel), down_proj (RowParallel) |
| 独立型 (Unfused + No TP) | 否 | 否 | input_layernorm, post_attention_layernorm |
| 特殊型 (Fused + No TP) | 是 | 否 | 较少见。除非模型极小或显存极大,不需要切分。 |
但是我们在写loader.py的时候不用考虑的那么细,我们只需要考虑怎么把已有参数加载到我们推理的具体模型里面去即可,具体的weight_loader下放到各自的Linear去实现,可能不好理解,需要多看看源码
我们介绍一下具体的loader.py逻辑
# 默认的权重加载,针对的是最普通的linner,即独立型 (Unfused + No TP),比如LayerNorm / RMSNorm,我们直接复制checkpoint里的Tensor给对应的param
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
def load_model(model: nn.Module, path: str):
# 获取一下packed_modules_mapping字典,没有的话就是空字典
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
# 遍历所有.safetensors文件
for file in glob(os.path.join(path, "*.safetensors")):
# 以 safe_open 方式打开,"pt" 指返回 PyTorch 张量,"cpu" 指加载到内存
with safe_open(file, "pt", "cpu") as f:
# 通过f.keys()获取.safetensors文件里面所有的参数名
for weight_name in f.keys():
# 检查当前的权重名是否命中映射表(即:它是否是某个大层的一部分)
for k in packed_modules_mapping:
# 如果是,说明这个是融合算子,我们需要把参数整合到融合后的具体位置
if k in weight_name:
# v: 目标大层的名字,shard_id: 这部分权重在大层里的“身份”
v, shard_id = packed_modules_mapping[k]
# 换名:把小层名换成大层名(如把 q_proj 换成 qkv_proj)
param_name = weight_name.replace(k, v)
# 找到内存中真正的参数容器(Parameter 对象)
param = model.get_parameter(param_name)
# 获取该层专属的“加载器”(通常是分布式层自带的特殊函数)
weight_loader = getattr(param, "weight_loader")
# 执行加载:由于是融合层,这里会传入 shard_id 指导如何拼接
weight_loader(param, f.get_tensor(weight_name), shard_id)
# 处理完这个权重,跳出映射表循环,继续下一个权重名
break
else:
# 如果上面的 for 循环没有触发 break(即没命中融合映射)
# 直接按原始名称寻找模型中的参数
param = model.get_parameter(weight_name)
# 尝试获取自定义加载器,如果没有,就用最简单的“全量拷贝”函数
weight_loader = getattr(param, "weight_loader", default_weight_loader)
# 执行加载:这里没有 shard_id,是 1 对 1 的搬运
weight_loader(param, f.get_tensor(weight_name))
外层遍历的逻辑其实很好理解,就是在枚举checkpoint里面的参数,然后看看我们的字典是否有命中,有命中的话就换个名字,换完名字就要想办法把具体参数给他塞进去,这里的用法看着就没那么简单明了,我们可以具体学习一下,主要是下面着三行代码
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader")
weight_loader(param, f.get_tensor(weight_name), shard_id)
-
model.get_parameter(param_name)这个也很好理解,就是根据get_parameter方法,按param_name的字符串名字去模型里按树形结构遍历得到对应的底层的nn.Paramer, -
weight_loader = getattr(param, "weight_loader"),这个代码可能会让你有点疑惑,什么是getattr?为什么可以对param进行getattr?具体是怎么定义的?-
首先,
getattr是 Python 用来“通过字符串访问对象属性”的内建机制,他和.的语义是一样的,都是访问对象的属性,但是使用.的前提是,你明确知道一定会有这个属性,否则Python 会立刻抛出AttributeError,程序直接中断,是静态写死的。而getattr可以以更安全的形式去运行,是动态获取的,如果你没有找到这个属性,可以返回一个指定的默认属性,即getattr(obj, "x", default),如果obj有x这个属性,那就返回x,如果没有的话我们就返回默认的default,当然,这个属性既可以是变量,也可以是方法 -
而
param之所以能被getattr,是因为nn.Parameter本身就是一个普通的 Python 对象,可以被动态加属性。 -
我们可以看一下具体的定义方法
class LinearBase(nn.Module): def __init__( self, input_size: int, output_size: int, bias: bool = False, tp_dim: int | None = None, ): super().__init__() self.tp_dim = tp_dim self.tp_rank = dist.get_rank() self.tp_size = dist.get_world_size() self.weight = nn.Parameter(torch.empty(output_size, input_size)) self.weight.weight_loader = self.weight_loader if bias: self.bias = nn.Parameter(torch.empty(output_size)) self.bias.weight_loader = self.weight_loader else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError -
在
LinearBase初始化的时候,我们定义了self.weight = nn.Parameter(torch.empty(output_size, input_size)),然后通过self.weight.weight_loader = self.weight_loader的语句给weight的nn.Parameter类挂载了weight_loader属性, -
由于子类会 override
weight_loader方法,而LinearBase.__init__里挂载的是self.weight_loader这个绑定方法,所以最终参数上绑定的是子类的weight_loader实现。 -
比如:
class ColumnParallelLinear(LinearBase): def __init__( self, input_size: int, output_size: int, bias: bool = False, ): tp_size = dist.get_world_size() super().__init__(input_size, divide(output_size, tp_size), bias, 0) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) param_data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> torch.Tensor: return F.linear(x, self.weight, self.bias) class QKVParallelLinear(ColumnParallelLinear): def __init__( self, hidden_size: int, head_size: int, total_num_heads: int, total_num_kv_heads: int | None = None, bias: bool = False, ): tp_size = dist.get_world_size() total_num_kv_heads = total_num_kv_heads or total_num_heads self.head_size = head_size self.num_heads = divide(total_num_heads, tp_size) self.num_kv_heads = divide(total_num_kv_heads, tp_size) output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size super().__init__(hidden_size, output_size, bias) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): param_data = param.data assert loaded_shard_id in ["q", "k", "v"] if loaded_shard_id == "q": shard_size = self.num_heads * self.head_size shard_offset = 0 elif loaded_shard_id == "k": shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size else: shard_size = self.num_kv_heads * self.head_size shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank] param_data.copy_(loaded_weight)
-
-
weight_loader(param, f.get_tensor(weight_name), shard_id),就是调用我们刚刚拿出来的weight_loader函数,把权重塞到应该去的地方