prepare_prefill和prepare_block_tables
prepare_prefill 函数的作用,本质上是为大模型推理中的 prefill 阶段做一次完整的数据整理与运行时环境构建,它并不是简单地把多个序列拼接起来,而是在一个支持 block 级 KV cache、prefix cache 复用以及 FlashAttention 的高性能推理框架中,将多个变长、可能部分已缓存的序列,转换成一次可以直接送入 GPU kernel 执行的结构化输入。
def prepare_prefill(self, seqs: list[Sequence]):
input_ids = []
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
block_tables = None
for seq in seqs:
seqlen = len(seq)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
seqlen_q = seqlen - seq.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(max_seqlen_q, seqlen_q)
max_seqlen_k = max(max_seqlen_k, seqlen_k)
if not seq.block_table:
continue
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = seq.block_table[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
else :
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
if cu_seqlens_k[-1] > cu_seqlens_q[-1]:
self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
return input_ids, positions
所以,我们prefill 阶段的目标是将当前批次的seq整理成input_ids和positions,同时更新 Context 中相关的元数据,以便于底层算子直接读取调用,比如cu_seqlens_q、cu_seqlens_k、max_seqlen_q、max_seqlen_k、slot_mapping、block_tables等
这些元数据我们只需要根据相关定义直接扫一遍seqs就可以得出来。需要注意的是,prefill 阶段的目标是将 prompt 中尚未计算的 token 一次性前向计算,并把对应的 K/V 写入 KV cache,为后续 decode 做准备;而在一个支持 prefix cache 的系统里,每个序列前面一部分 token 可能已经算过并缓存好了,因此这里只需要处理“未缓存”的部分。
input_ids.extend(seq[seq.num_cached_tokens:])把未缓存的输入数据收集起来positions.extend(list(range(seq.num_cached_tokens, seqlen)))构造对应input_ids的positions,我们是从未缓存的地方开始计算-
seqlen_q = seqlen - seq.num_cached_tokens seqlen_k = seqlen cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)前缀和数组,维护起来很简单
-
max_seqlen_q = max(max_seqlen_q, seqlen_q) max_seqlen_k = max(max_seqlen_k, seqlen_k)维护一下两个最大值,FlashAttention 需要知道最大长度用于分块
-
if not seq.block_table: continue这是用于warmup的时候,跳过
slot_mapping的构造预热过程旨在通过让 GPU 空跑伪造的填充数据,强制触发 CUDA 上下文初始化、显存分配器就绪以及底层 C++ 算子的编译与加载,从而彻底消除真实线上请求的冷启动延迟;然而,由于这些伪造数据计算出的 KV Cache 是毫无意义的逻辑垃圾,如果为其执行
slot_mapping,将导致这些垃圾数据直接写入并严重污染极其宝贵的 PagedAttention 全局物理显存池,并带来无谓的内存分配与释放开销。
-
for i in range(seq.num_cached_blocks, seq.num_blocks): start = seq.block_table[i] * self.block_size if i != seq.num_blocks - 1: end = start + self.block_size else: end = start + seq.last_block_num_tokens slot_mapping.extend(list(range(start, end)))根据输入的seqs构造一下
slot_mapping -
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: self.prepare_block_tables(seqs) def prepare_block_tables(self, seqs: list[Sequence]): max_len = max(len(seq.block_table) for seq in seqs) block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs] block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) return block_tables如果
cu_seqlens_k[-1] > cu_seqlens_q[-1]说明本次的输入数据是经历过了prefix caching,即有一些kv cache是存在我们的block 里面,我们是需要取出来,因为后面prefill计算注意力的时候是需要用到之前的KV cache。block table是一个二维矩阵,我们需要先获取最大长度max_len,然后创建对应的矩阵,对于没有的位置用-1来代替 -
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)这段代码是预填充(Prefill)数据准备阶段跨越 CPU 与 GPU 物理边界的“终极交接仪式”:它首先利用锁页内存(
pin_memory)配合异步 DMA(non_blocking),将 CPU 端精心编排好的输入序列、绝对位置索引、用于 FlashAttention 的变长序列边界(cu_seqlens)以及用于 PagedAttention 的物理内存路由表(slot_mapping)以零阻塞的极致性能高速空投至 GPU 显存;紧接着,通过set_context将这批复杂的底层寻址指针与维度元数据整体挂载到系统的全局上下文中,从而完美屏蔽了底层的调度复杂性,确保底层的 C++ 算子能够直接从该上下文中精准抓取所需参数。
prepare_decode
def prepare_decode(self, seqs: list[Sequence]):
input_ids = []
positions = []
slot_mapping = []
context_lens = []
for seq in seqs:
input_ids.append(seq.last_token)
positions.append(len(seq) - 1)
slot_mapping.append((seq.block_table[-1]) * self.block_size + seq.last_block_num_tokens - 1)
context_lens.append(len(seq))
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
return input_ids, positions
prepare_decode 函数的作用是为大模型推理中的 decode 阶段做一次数据整理与运行时环境构建,和prepare_prefill函数比较类似,由于decode阶段,每个序列的输入都是上一个token,所以我们省去了维护cu_seqlen_q、cu_seqlen_k、max_seqlen_q、max_seqlen_k这四个变量,具体代码比较简单,我们不过多赘述
prepare_sample
def prepare_sample(self, seqs: list[Sequence]):
temperatures = []
for seq in seqs:
temperatures.append(seq.temperature)
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
return temperatures
获取每个序列的采样参数
run()和run_model()
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
return self.model.compute_logits(self.model(input_ids, positions))
else:
bs = input_ids.size(0)
context = get_context()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
graph_vars = self.graph_vars
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"].fill_(-1)
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"].zero_()
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])
这段代码实现了一个推理执行器中的核心函数 run_model,其主要目的是在大型语言模型的推理过程中,根据当前是预填充(Prefill)还是解码(Decode)阶段,灵活地选择是直接运行模型还是利用 CUDA Graphs 来加速执行。
函数通过 @torch.inference_mode() 装饰器确保在推理时不记录梯度,从而优化性能。
在执行逻辑的开始,它首先进行条件判断。如果当前处于 is_prefill 阶段(即处理用户输入的 Prompt)、或者设置了强制使用 Eager 模式(enforce_eager)、再或者当前处理的 Token 数量(input_ids.size(0))超过了 512 的阈值,代码将走传统的执行路径。此时,它会直接调用 self.model(...) 并计算逻辑输出(logits)。这种做法是因为预填充阶段的输入长度是动态且多变的,CUDA Graphs 这种静态捕获技术在处理变长输入时不仅维护成本高,而且由于 GPU 计算核心通常已经饱和,图捕获带来的开销缩减边际效应递减。
相反,如果上述条件都不满足,说明目前正处于解码阶段的 Auto-regressive 步骤。由于解码通常是逐个 Token 生成,Batch Size 通常较小且固定,此时 CPU 发射算子的开销会成为主要瓶颈。为了消除这种开销,代码利用了预先捕获好的 CUDA Graph。它首先根据当前的 Batch Size 从 self.graphs 中选取一个预分配好的计算图。为了实现这一点,代码直接将外部输入的 input_ids、positions 以及 KV Cache 相关的元数据(如 slot_mapping、context_lens 和 block_tables)手动拷贝到固定的 graph_vars 张量中。这些张量是 CUDA Graph 在捕获阶段就已经绑定的内存地址,因此这种“搬运”操作是更新图输入的关键。完成数据填充后,通过调用 graph.replay(),GPU 会直接执行之前记录的一连串算子序列,完全跳过了 CPU 的介入和算子调度。最后,函数从预定义的输出张量 graph_vars["outputs"] 中截取对应 Batch 的部分,并计算并返回最终的 logits 结果,从而实现了极低延迟的 Token 生成。
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids
首先根据布尔变量 is_prefill 的状态来决定输入数据的组织方式。如果处于预填充阶段,它调用 prepare_prefill 将原始序列转化为模型可理解的 input_ids 和位置编码 positions,这通常涉及处理较长的 Prompt 输入;如果处于解码阶段,则通过 prepare_decode 处理增量生成的单个 Token。与此同时,针对采样逻辑,代码通过 prepare_sample 准备采样所需的温度参数(temperatures),但为了避免冗余计算,这部分操作仅在主进程(self.rank == 0)中执行。
紧接着,核心计算任务被移交给 run_model 函数。正如之前分析的,该步骤会根据输入规模自动切换 Eager 模式或 CUDA Graphs 模式来高效计算出 logits(未归一化的概率分布)。在获取到 logits 后,主进程利用 self.sampler 根据设定的温度进行概率采样,选出下一个 Token 的 ID,并将其转换为列表形式。在返回结果之前,函数调用 reset_context() 清理当前的推理上下文(如 slot_mapping 或临时张量状态),确保下一轮推理的干净启动。