从0开始复现nano-vllm「llm_engine.py」

llm_engine

import atexit
from dataclasses import fields
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import torch.multiprocessing as mp

from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner


class LLMEngine:

    def __init__(self, model, **kwargs):
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        ctx = mp.get_context("spawn")
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        self.model_runner = ModelRunner(config, 0, self.events)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config)
        atexit.register(self.exit)

    def exit(self):
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        seq = Sequence(prompt, sampling_params)
        self.scheduler.add(seq)

    def step(self):
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
        return outputs, num_tokens

    def is_finished(self):
        return self.scheduler.is_finished()

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[str]:
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        for prompt, sp in zip(prompts, sampling_params):
            self.add_request(prompt, sp)
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        while not self.is_finished():
            t = perf_counter()
            output, num_tokens = self.step()
            if use_tqdm:
                if num_tokens > 0:
                    prefill_throughput = num_tokens / (perf_counter() - t)
                else:
                    decode_throughput = -num_tokens / (perf_counter() - t)
                pbar.set_postfix({
                    "Prefill": f"{int(prefill_throughput)}tok/s",
                    "Decode": f"{int(decode_throughput)}tok/s",
                })
            for seq_id, token_ids in output:
                outputs[seq_id] = token_ids
                if use_tqdm:
                    pbar.update(1)
        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        if use_tqdm:
            pbar.close()
        return outputs

这段代码实现了一个轻量级且高效的大语言模型推理引擎的核心控制器。它的主要作用是充当整个文本生成任务的“总指挥”,对外提供了一个简单易用的批量文本生成接口,对内则完美封装并统筹了所有复杂的底层运行机制——包括文本数据的分词转换、多 GPU 协同的张量并行分布式计算、以及优化系统吞吐量与显存的连续批处理调度,从而驱动庞大的 AI 模型稳定、高效地完成从接收用户请求到最终输出生成文本的完整推理生命周期。

def __init__(self, model, **kwargs):
    config_fields = {field.name for field in fields(Config)}
    config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
    config = Config(model, **config_kwargs)
    self.ps = []
    self.events = []
    ctx = mp.get_context("spawn")
    for i in range(1, config.tensor_parallel_size):
        event = ctx.Event()
        process = ctx.Process(target=ModelRunner, args=(config, i, event))
        process.start()
        self.ps.append(process)
        self.events.append(event)
    self.model_runner = ModelRunner(config, 0, self.events)
    self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
    config.eos = self.tokenizer.eos_token_id
    self.scheduler = Scheduler(config)
    atexit.register(self.exit)

在初始化阶段,首先根据传入参数构造 Config 对象

  • fields(Config)dataclasses 提供的一个内置函数,它会返回 Config 类中定义的所有字段对象。借此获得Config类内所有的变量名。
  • config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}是通过遍历传入的kwargs字典,把在Config内的键值对保留下来
  • config = Config(model, **config_kwargs),然后就可以初始化对应的Config对象了

然后搭建张量并行的多进程分布式计算环境

  • self.ps用于记录所有创建的子进程,用self.events记录每个子进程的事件锁,这个事件我们前面介绍过了,这里不过多介绍

  • ctx = mp.get_context("spawn")设定进程启动模式,mp 是 Python 的多进程模块 multiprocessing。在 Linux 系统下,创建子进程默认的方式是 fork(直接复制当前进程状态)。但是,fork 和 PyTorch 的 CUDA 不兼容,如果带着初始化的 CUDA 上下文直接 fork,极易导致程序死锁、显存泄漏或直接崩溃。 spawn 模式会启动一个全新的、干净的 Python 解释器环境,再重新导入必要的模块。这是 PyTorch 官方强行规定在多 GPU 编程下必须使用的安全模式。

  • # 创建子进程和对应的事件锁,由于当前已经有一个进程了,所以只需要再创建`tensor_parallel_size-1`个即可
    for i in range(1, config.tensor_parallel_size):
        event = ctx.Event() # 新建一个事件锁
        process = ctx.Process(target=ModelRunner, args=(config, i, event)) # 创建子进程,告诉他去运行ModelRunner这个代码,参数是(config, i, event)
        process.start() # 正式在后台启动这个进程(此时子进程对应的 GPU 开始被分配显存)
    	# 将启动好的进程对象和它对应的事件对象,存放到建立好的列表里
        self.ps.append(process)
        self.events.append(event)
    
  • 然后再初始化一下主进程的ModelRunner,以及tokenizerscheduler

    self.model_runner = ModelRunner(config, 0, self.events)
    self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
    config.eos = self.tokenizer.eos_token_id
    self.scheduler = Scheduler(config)
    
    • 给当前的主进程创建一个model_runner对象,因为是主进程,所以要把前面的events列表传进去统一管理
    • AutoTokenizertransformers 库中极其强大的一个类,会根据你选择的模型,自动识别并实例化最匹配的分词器
    • .from_pretrained(config.model)是从预训练库中加载专属字典
    • use_fast=True会调用由 Rust 语言编写的 tokenizers 核心库,加快速度
    • 更新一下config.eos,然后创建调度器scheduler
  • atexit.register(self.exit)
    

    atexit 是 Python 的一个内置标准库(at exit 的缩写,意思是“在退出时”)。atexit.register(self.exit) 的核心作用是:向 Python 解释器立下一份“遗嘱”,无论这个程序是正常运行结束,还是被用户按 Ctrl+C 强行中断,在程序彻底死亡的前一秒,都必须无条件执行 self.exit 这个函数。

def exit(self):
    self.model_runner.call("exit")
    del self.model_runner
    for p in self.ps:
        p.join()

model_runner.call(method_name)函数是一个在ModelRunner类内调用函数名为method_name方法的方法,在ModelRunner我们已经定义好了exit()函数,所以这里是让主进程使用这个函数,优雅地关闭整个推理引擎及其相关的多进程资源

接下来执行 del self.model_runner,这一句的作用是显式删除当前进程中持有的 ModelRunner 对象引用,从而触发其析构逻辑或释放其占用的 Python 资源和底层 CUDA 资源,避免内存泄漏或残留的显存占用。

最后通过遍历 self.ps 中保存的子进程对象并调用 p.join(),主进程会阻塞等待每一个子进程完全结束运行。join() 的语义是“等待该进程终止”,这可以保证所有并行进程都已经正常退出后主程序才继续向下执行或结束,从而避免出现僵尸进程或资源未释放的问题。整体来看,这段代码实现的是一个标准的多进程优雅退出流程:先发送退出信号,再释放本地资源,最后等待所有子进程安全结束。

def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
    if isinstance(prompt, str):
        prompt = self.tokenizer.encode(prompt)
    seq = Sequence(prompt, sampling_params)
    self.scheduler.add(seq)

这个 add_request 方法的作用是把一个新的生成请求注册到引擎内部的调度系统中,为后续的批处理推理做准备。函数首先接收 prompt 和对应的 SamplingParams,其中 prompt 既可以是字符串,也可以是已经分词好的 token id 列表;如果传入的是字符串,就通过当前加载的 tokenizer 调用 encode 方法把自然语言文本转换成整数 token 序列,这一步实际上完成了从文本空间到模型输入空间的映射。接着使用处理后的 token 序列和采样参数构造一个 Sequence 对象,这个对象通常封装了该请求的全部状态信息,例如当前已有的 token、生成进度、是否完成、采样策略(如 temperature、top-p 等)以及后续解码所需的缓存信息。最后调用 self.scheduler.add(seq) 将这个 Sequence 加入调度器内部的等待队列中,由 Scheduler 统一管理多个请求的批次组织与执行顺序。

def step(self):
    seqs, is_prefill = self.scheduler.schedule()
    token_ids = self.model_runner.call("run", seqs, is_prefill)
    self.scheduler.postprocess(seqs, token_ids)
    outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
    num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
    return outputs, num_tokens

这个 step 方法是整个推理引擎的核心执行单元,每调用一次就推进所有正在生成的请求一步,相当于驱动一次批量前向计算。

首先第一行 seqs, is_prefill = self.scheduler.schedule() 是调度阶段,它从调度器中取出当前应该执行的一批 Sequence,并返回一个布尔值 is_prefill 表示当前阶段是 prefill 阶段 还是 decode 阶段。prefill 指的是把每个请求的完整 prompt 一次性送入模型建立 KV cache;decode 则是已经完成预填充后,每个序列只生成一个新 token 的阶段。调度器在这里负责动态批处理,把不同请求组织成一个可并行执行的小批次。

接着 token_ids = self.model_runner.call("run", seqs, is_prefill) 才是真正的模型执行阶段。这里通过 ModelRunner 发起一次 "run" 调用,把当前批次的序列和阶段信息传入底层推理模块。ModelRunner 会负责准备张量、构建 attention 输入、执行前向传播,并返回新生成的 token id。

然后 self.scheduler.postprocess(seqs, token_ids) 是后处理阶段。调度器根据返回的 token 更新每个 Sequence 的内部状态,例如追加新 token、判断是否遇到 EOS、检查是否达到最大长度、更新缓存映射等。如果某个序列生成完成,它会被标记为 finished,并在后续调度中移除。

接下来这一行:

outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]

会从当前批次中筛选出已经完成的序列,并返回它们的 seq_id 以及完整生成结果 completion_token_ids。注意这里只返回“刚刚完成”的请求,而不是所有请求。

最后一行:

num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)

这是为了统计吞吐量。如果当前是 prefill 阶段,就返回本批次处理的 token 总数(正值);如果是 decode 阶段,则返回生成的 token 数量(每个序列生成一个 token,所以是 len(seqs)),但用负号区分阶段。外层的 generate() 会根据正负号分别统计 prefill 吞吐率和 decode 吞吐率。

def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[str]:
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        for prompt, sp in zip(prompts, sampling_params):
            self.add_request(prompt, sp)
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        while not self.is_finished():
            t = perf_counter()
            output, num_tokens = self.step()
            if use_tqdm:
                if num_tokens > 0:
                    prefill_throughput = num_tokens / (perf_counter() - t)
                else:
                    decode_throughput = -num_tokens / (perf_counter() - t)
                pbar.set_postfix({
                    "Prefill": f"{int(prefill_throughput)}tok/s",
                    "Decode": f"{int(decode_throughput)}tok/s",
                })
            for seq_id, token_ids in output:
                outputs[seq_id] = token_ids
                if use_tqdm:
                    pbar.update(1)
        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        if use_tqdm:
            pbar.close()
        return outputs

这个 generate 函数是整个推理引擎对外暴露的核心接口,它封装了从“接收请求”到“返回生成文本”的完整生命周期,相当于一个批量推理调度控制器。下面我们按执行流程分段详细解释。

首先是初始化阶段。如果 use_tqdm=True,就创建一个进度条对象 pbar,总长度为 len(prompts),表示一共有多少条生成请求。desc="Generating" 是前缀文字,dynamic_ncols=True 表示自动适配终端宽度。这个进度条的设计单位是“完成的序列数”,而不是 token 数。接下来对 sampling_params 做一次类型归一化处理,如果用户只传入一个 SamplingParams,就复制成与 prompts 等长的列表,这样每个 prompt 都有对应的采样策略,便于统一调度。

然后进入“请求入队阶段”。通过

for prompt, sp in zip(prompts, sampling_params):
    self.add_request(prompt, sp)

逐个把请求加入调度器。add_request 内部会把字符串转成 token id,然后封装成 Sequence 对象,并加入 Scheduler 的等待队列。此时只是注册请求,还没有执行模型前向计算。

接下来初始化两个吞吐率统计变量 prefill_throughputdecode_throughput,并进入主循环:

while not self.is_finished():

只要调度器中还有未完成的序列,就不断调用 self.step() 推动生成流程。每一轮都会记录开始时间 t = perf_counter(),然后执行一次 step()step() 内部会完成一次调度 → 模型前向 → 状态更新 → 返回已完成序列。

step() 返回两个值:

  • output:当前这一轮刚刚完成的序列结果(可能为 0 条或多条)
  • num_tokens:本轮处理的 token 数量(正数表示 prefill 阶段,负数表示 decode 阶段)

接下来是吞吐率统计逻辑。如果开启了进度条,就根据 num_tokens 的正负来区分阶段:

  • 如果 num_tokens > 0,说明当前是 prefill 阶段,表示一次性处理了若干 prompt token,此时吞吐率 = 处理 token 数 / 耗时。
  • 如果 num_tokens < 0,说明是 decode 阶段,因为 decode 是每个序列生成一个 token,所以返回的是 -len(seqs),这里取负号计算实际生成 token 数。

然后用 pbar.set_postfix() 动态更新进度条右侧的统计信息,例如:

Generating:  50%|██████     | 5/10 [00:03<00:03, Prefill=12000tok/s, Decode=45tok/s]

这部分的设计其实是在实时区分“构建 KV cache 的吞吐”和“逐 token 解码的吞吐”,因为两者性能特征完全不同。

然后是结果收集阶段:

for seq_id, token_ids in output:
    outputs[seq_id] = token_ids

这里只处理“刚刚完成”的序列。因为有些序列可能早完成,有些晚完成,所以用字典 outputsseq_id 作为键保存结果,避免顺序错乱。如果使用进度条,每完成一个序列就 pbar.update(1),表示整体任务进度前进一格。

while 循环结束时,说明所有序列都生成完成。接下来做结果整理:

outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]

这一步是按 seq_id 排序,恢复原始输入顺序,因为调度器内部可能是乱序执行的。

然后进行解码:

outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]

把 token id 列表转换成可读文本,同时保留原始 token id,最终返回的是一个字典列表,每个元素包含:

{
    "text": "...生成的文本...",
    "token_ids": [...]
}

 

博客内容均系原创,未经允许严禁转载!
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇