block_manager.py
class Block:
def __init__(self, block_id):
self.block_id = block_id # 初始化块的唯一id
self.ref_count = 0 # 当前有多少序列在共享该块
self.hash = -1 # 初始化哈希值为-1
self.token_ids = [] #当前物理块存的token id的列表
def update(self, hash: int, token_ids: list[int]): # update 用来把一个已经算好的 block 标记为“可复用”
self.hash = hash # 更新哈希值
self.token_ids = token_ids # 更新token id
def reset(self): # reset 则表示这是一个刚被分配出来的全新 block,此时 ref_count 设为 1,表示立刻被一个序列占用,在 reset 的那一刻,KV cache 还没算完,甚至可能还没算
self.ref_count = 1 # 表示被一个序列占用
self.hash = -1 # 初始化哈希,-1表示现在还不能被别人复用
self.token_ids = [] # 初始化token ids为空,表示现在里面的内容还不完整 / 不稳定
Block 表示的是一段固定大小的 KV cache 存储单元,但它并不是“某个序列的 block”,而是一个可以在多个序列之间流转、被共享、被回收的缓存对象。
一开始,__init__ 给这个 block 一个永远不变的身份 block_id,然后把它放在一个“完全未使用”的初始状态。此时 ref_count = 0 的含义非常明确:现在没有任何序列在使用它;hash = -1 并不是说它“算错了 hash”,而是明确标记“这个 block 当前不具备任何可复用语义”;token_ids = [] 只是一个占位,用来表示这里还没有被写入稳定内容。
reset 和 update,对应了 block 生命周期中两个完全不同、而且不能颠倒顺序的阶段。
reset 发生在一个 block 从“空闲态”进入“被某个序列占用”的那一刻。它的语义不是“清空”,而是重新初始化为一个正在使用中的 block。所以它第一件事就是把 ref_count 设为 1,这表示:从现在开始,已经有且仅有一个序列持有这个 block 的引用了;与此同时,它把 hash 重新置为 -1,明确宣告“这个 block 当前不可被共享”,哪怕它之后会逐步被写入 token;而 token_ids 被清空,则是因为在这一刻,block 内部的数据还处在“生成中”,不具备稳定含义。这一步非常重要,因为它从结构上防止了“未完成 block 被误加入共享池”的可能。
随后,在生成或初始化过程中,这个 block 会逐 token 填充 KV cache,但只要它还没被填满,它在逻辑上就始终是“私有的”。这一阶段,Block 本身并不关心 token 是如何写入的,它只通过 hash = -1 这一状态,告诉外界:我还没准备好被别人用。
update 则发生在完全不同的时间点:当一个 block 的 token 数量恰好等于 block_size,并且它前面的所有前缀 block 都已经有了稳定 hash 之后,系统才会调用 update。这一步的含义可以精确地说成:“这个 block 的内容已经固定下来,不会再变化,可以作为一个可复用的缓存单元被登记和共享。” 因此,update 会写入一个确定的 hash,这个 hash 把“前缀上下文 + 当前 block token”压缩成一个唯一指纹;同时保存 token_ids,并不是为了参与计算,而是作为一个安全校验,用来防止哈希冲突带来的错误复用。注意这里一个非常重要但容易被忽略的细节:update 并不修改 ref_count,因为 block 在被 reset 的那一刻,就已经被某个序列合法持有了,update 只是改变了它的“共享资格”,而不是它的“占用关系”。
如果把这三个方法连起来看,可以发现 Block 并不负责“管理谁在用我”,它只是维护三个状态量,让 BlockManager 能够在外部做出正确决策。ref_count 表示当前有多少序列在逻辑上引用这个 block;hash 表示这个 block 是否已经稳定、是否可以被纳入前缀共享体系;而 token_ids 则是一个辅助信息,用来保证复用的正确性。
class BlockManager:
def __init__(self, num_blocks: int, block_size: int):
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
def _allocate_block(self, block_id: int) -> Block:
block = self.blocks[block_id]
assert block.ref_count == 0
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return self.blocks[block_id]
def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= seq.num_blocks
def allocate(self, seq: Sequence):
assert not seq.block_table
h = -1
cache_miss = False
for i in range(seq.num_blocks):
token_ids = seq.block(i)
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
else:
seq.num_cached_tokens += self.block_size
if block_id in self.used_block_ids:
block = self.blocks[block_id]
block.ref_count += 1
else:
block = self._allocate_block(block_id)
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id)
def deallocate(self, seq: Sequence):
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
if len(seq) % self.block_size == 1:
assert last_block.hash != -1
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
elif len(seq) % self.block_size == 0:
assert last_block.hash == -1
token_ids = seq.block(seq.num_blocks-1)
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
assert last_block.hash == -1
BlockManager 是一个“运行中的 KV cache 调度器” ,它的核心职责不是算 hash、也不是存 token,而是在“序列生命周期变化”的过程中,持续维护 block 的占用、共享与回收这三件事之间的一致性。
def __init__(self, num_blocks: int, block_size: int):
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
BlockManager 在被初始化的时候,首先确定了块的大小,这是我们 kv cache 共享的最小粒度。接着一次性创建了固定数量的 Block 对象,数量是num_blocks个,block 的数量是有限的且无法动态增长的资源
初始化阶段构建了三套“账本”:blocks 保存所有 block 的实体;free_block_ids 记录当前完全空闲、可以被 reset 并写入新内容的 block;used_block_ids 记录当前至少被一个序列引用的 block;而 hash_to_block_id 则是连接“内容语义”和“物理 block”的桥梁,它只服务于一件事:给定一个确定的前缀上下文和 token 内容,能否直接定位到一个已经存在的 block。
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
compute_hash函数是一个计算哈希值的函数,给定前缀上下文的哈希值,以及当前输入的token,计算前缀上下文和当前token组合起来的哈希值。
其逻辑是,如果prefix非-1,则代表是存在上下文的,需要先把prefix扔到初始化的一个xhash64 的哈希对象里。然后把传进来的token ids转成一个连续的整数数组,然后用tobytes()得到稳定紧凑的二进制表示,然后喂给当前的哈希对象。
相当于当前block的hash = f(前一个block的hash + 当前block的token的内容),而前一个block 的哈希可能还是拼接的前前block,这是一种链式哈希,这和我们的目的是一样的:KV cache 的共享不仅要求这个 block 的 token 一样,还要求“它之前的上下文也一样。同样的 token 序列,一定会产生完全一致的字节表示,从而保证 hash 的确定性和可复现性。
返回的是h.intdigest(),是一个64位整数,作为我们最终的哈希值
我们有三个序列
A:[1, 2, 3, 4 | 5, 6, 7, 8]
B:[1, 2, 3, 4 | 5, 6, 7, 8 | 9, 10]
C:[0, 2, 3, 4 | 5, 6, 7, 8 | 9, 10]
对于序列A的第一块 ,第二块
对于序列B的第一块 ,第二块,第三块,可以发现序列B的前两块内容和A完全相同,这个时候就可以复用A的KV cache,只需要计算和存储后面的内容
对于序列C,从第一个token开始就和A和B对不上,就完全无法复用他们的KV cache,只能重新计算并存储
def _allocate_block(self, block_id: int) -> Block:
block = self.blocks[block_id]
assert block.ref_count == 0
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return self.blocks[block_id]
这个 _allocate_block 函数是一个典型的底层资源管理器中的私有方法,其核心任务是将一个处于“空闲”状态的数据块正式标记为“已占用”并准备好供外部使用。代码首先通过传入的 block_id 在预先分配好的 self.blocks 数组或字典中定位到具体的 Block 对象,紧接着执行一个关键的安全检查:利用 assert block.ref_count == 0 确保该块当前的引用计数确实为零,这是一种防错机制,防止程序意外覆盖掉仍在被引用的活跃数据。
def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
这个函数 _deallocate_block 是内存管理或缓存系统中一个典型的资源回收机制,主要负责将一个不再被使用的物理块从“已占用”状态归还到“可分配”状态。在代码执行之初,assert self.blocks[block_id].ref_count == 0 起到了核心的防御性编程作用,它强制要求该块的引用计数必须为零;如果此时还有其他程序指向这个块,断言就会触发报错,防止因误删正在使用的数据而导致系统崩溃或数据损坏。
def allocate(self, seq: Sequence):
assert not seq.block_table
h = -1
cache_miss = Talse
for i in range(seq.num_blocks):
token_ids = seq.block[i]
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or token_ids != self.blocks[block_id].token_ids:
cache_miss = False
if cache_miss :
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
else :
seq.num_cached_tokens += self.block_size
if block_id in self.used_block_ids :
block = self.blocks[block_id]
block.ref_count += 1
else:
block = self._allocate_block(block_id)
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id)
这段代码是 BlockManager 的核心,负责为一个新的序列分配物理内存块,并在这个过程中尽可能地复用已有的缓存。
首先,assert not seq.block_table 确保该序列是一个全新的、尚未分配任何物理块的请求。
接着,初始化哈希值 h = -1 和一个关键的布尔标记 cache_miss = False。
进入循环后,它按顺序遍历序列所需的所有块。对于每一个块,首先取出其对应的 token_ids。如果这个块是满的(长度等于 block_size),就调用 compute_hash 结合之前的哈希值 h 计算出一个新的链式哈希;如果是不满的残缺块,则将 h 置为 -1。因为,h=-1表示的是,可以复用这个块,如果是残缺的,我们认为它不可以被复用,所以要让h为-1。
随后,程序尝试从 hash_to_block_id 字典中寻找这个哈希值对应的 block_id。如果找不到(即等于 -1),或者虽然找到了 ID 但物理块内存储的 token 与当前不一致(这里是双重保险,防止哈希冲突),则将 cache_miss 设为真,意味着从这一块开始,后面的内容都无法命中缓存了。
在处理具体分配时,如果 cache_miss 为真,说明没有可复用的缓存,系统直接从 free_block_ids 队列的头部取出一个最老的空闲块,并调用内部方法 _allocate_block 进行物理分配。反之,如果命中缓存,系统会增加序列记录的已缓存 token 数量 seq.num_cached_tokens,并检查该物理块的状态:如果该块当前正在被其他序列使用(存在于 used_block_ids 中),则简单地将该块的引用计数 ref_count 加 1,实现内存共享;如果该块虽然在哈希表里但目前处于空闲状态,则调用 _allocate_block 重新激活它。处理完分配后,如果当前块是一个有效的满块(h != -1),系统会调用 block.update 更新该物理块的内部指纹,并在全局哈希映射表中登记或更新这个哈希值与物理 ID 的对应关系。最后,将确定好的物理块 ID 添加到当前序列的 block_table 中,完成逻辑块到物理块的映射。
def deallocate(self, seq: Sequence):
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
这段代码负责执行物理内存的回收逻辑,其核心机制是基于引用计数(Reference Counting)来确保安全释放。程序首先通过 reversed(seq.block_table) 倒序遍历序列所占用的物理块 ID 列表,采用倒序通常是因为序列末尾的块更有可能是独占的,从而能更快地腾出连续空间。对于每一个 block_id,系统会找到对应的物理块对象并将其 ref_count 引用计数减 1。这一步至关重要,因为在支持前缀缓存共享的机制下,一个物理块可能同时被多个序列引用,直接删除会导致其他序列的计算出错。只有当引用计数减到 0 时,说明没有任何序列再使用该块,系统才会调用 _deallocate_block 真正将其移出已使用集合并放回空闲块队列中,使其能够被后续的新请求重新分配。在所有块处理完毕后,代码会将该序列记录的缓存 Token 计数清零并清空其逻辑块表,彻底断开序列与物理内存的关联。
def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
该函数是用来判断,是否有足够的空闲物理块来支持序列继续增长。判断当前序列长度是否刚刚溢出到了一个新的块,如果是,则需要消耗一个新的块,需要检查空闲块列表的长度是否大于等于1。
def may_append(self, seq: Sequence):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
if len(seq) % self.block_size == 1:
assert last_block.hash != -1
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
elif len(seq) % self.block_size == 0:
assert last_block.hash == -1
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
token_ids = seq.block(seq.num_blocks - 1)
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
assert last_block.hash == -1
这段代码是根据序列当前生成的长度来自动管理其底层的物理内存块的状态,主要负责处理分配新块和封存满块这两个关键时间节点。当序列长度刚刚超过块的容量,也就是余数等于1的时候,此时需要开一个新块来存,首先要判断一下上一个块的哈希是否非-1,非-1代表他已经满了,可以被复用,然后申请一个新的物理块来存当前生成的新值,并更新block_table;当序列长度刚好是块容量的整数倍,即余数是0的时候,此时我们需要封存当前块,判断一下当前的块的哈希值是否是-1,代表他还没被封存,我们现在给他封存,要计算其哈希,需要上下文的哈希,通过上一个块的哈希即可获得,token_ids可以通过seq来获得,计算对应的哈希值,更新last_block的哈希和token_ids,由于更新过哈希值,可以复用了,我们就要更新一下哈希表;对于其他模数,我们跳过就行