sampling_params.py
from dataclasses import dataclass
@dataclass
class SamplingParams:
temperature: float = 1.0
max_tokens: int = 64
ignore_eos: bool = False
def __post_init__(self):
assert self.temperature > 1e-10, "greedy sampling is not premitted"
SampllingParams类
-
temperature-
定义:温度系数,控制生存随机性的核心参数
-
作用:在 softmax 归一化之前,将模型输出的 Logits 除以
:模型变得非常确定(“保守”),倾向于只选概率最高的词。 :模型保持原始的概率分布。 :概率分布变得“平滑”,模型会变得更“大胆”,增加输出的多样性,但也容易胡言乱语。
-
-
max_tokens- 定义:单次推理最多生成的 token 数量
- 作用:这是一个“硬计数”限制。一旦生成的 token 数量达到 64,无论句子是否写完,推理引擎都会强制停止。这有助于防止模型陷入无限循环或产生过长的响应。
-
ignore_eos-
定义:是否忽略结束符
-
作用:
- 如果是Flase,模型一旦预测出结束EOS标记,输出就停止了
- 如果是True,即使模型输出了EOS标记,也无法停止,会一直生成到max_tokens为止,通常用于压力测试或者强制生产特定长度的文本
-
assert self.temperature > 1e-10, "greedy sampling is not permitted",设计的时候强制要求温度系数不能小于1e-10,即不支持纯贪婪搜索