sampling_params.py
  1. """Sampling parameters for text generation."""
  2. import copy
  3. from enum import IntEnum
  4. from functools import cached_property
  5. from typing import Callable, List, Optional, Union
  6. import torch
  7. _SAMPLING_EPS = 1e-5
  8. class SamplingType(IntEnum):
  9. GREEDY = 0
  10. RANDOM = 1
  11. BEAM = 2
  12. # We also accept KoboldAI's sampler IDs and convert to strings
  13. _sampler_map = {
  14. 0: "topk",
  15. 1: "topa",
  16. 2: "topp",
  17. 3: "tfs",
  18. 4: "typ",
  19. 5: "temp",
  20. 6: "pens",
  21. }
  22. LogitsProcessorFunc = Callable[[torch.Tensor, List[List[int]]], None]
  23. """LogitsProcessorFunc takes a logits tensor and corresponding lists of
  24. previously generated output tokens, and modifies the logits tensor."""
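
# Illustrative only: a function with the LogitsProcessorFunc signature. The
# name and behavior below are not part of this module; they assume `logits`
# is a [num_seqs, vocab_size] tensor and `output_tokens` holds each
# sequence's previously generated token IDs.
#
#     def damp_repeats(logits: torch.Tensor,
#                      output_tokens: List[List[int]]) -> None:
#         for i, seq in enumerate(output_tokens):
#             if seq:
#                 logits[i, seq] -= 1.0  # modify the logits tensor in place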


class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text
    completion API
    (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support several samplers that are not available through
    the OpenAI API.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the
            prompt. From these `best_of` sequences, the top `n` sequences
            are returned. `best_of` must be greater than or equal to `n`.
            This is treated as the beam width when `use_beam_search` is
            True. By default, `best_of` is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether
            they appear in the generated text so far. Values > 0 encourage
            the model to use new tokens, while values < 0 encourage the
            model to repeat tokens.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens.
        repetition_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Unlike frequency_penalty,
            which is applied additively, repetition_penalty is applied
            multiplicatively. Must be in [1, inf). Set to 1 to disable the
            effect.
        temperature: Float that controls the randomness of the sampling.
            Lower values make the model more deterministic, while higher
            values make the model more random. Zero means greedy sampling.
        top_p: Float that controls the cumulative probability of the top
            tokens to consider. Must be in (0, 1]. Set to 1 to consider all
            tokens.
        top_k: Integer that controls the number of top tokens to consider.
            Set to -1 to consider all tokens.
        top_a: Float that controls the cutoff for Top-A sampling. The exact
            cutoff is top_a*max_prob**2. Must be in [0, inf). Set to 0 to
            disable.
        min_p: Float that controls the cutoff for min-p sampling. The exact
            cutoff is min_p*max_prob. Must be in [0, 1]. Set to 0 to disable.
        tfs: Float that controls the cumulative approximate curvature of the
            distribution to retain for Tail Free Sampling. Must be in
            (0, 1]. Set to 1 to disable.
        eta_cutoff: Float that controls the cutoff threshold for Eta sampling
            (a form of entropy-adaptive truncation sampling). The threshold
            is computed as min(eta, sqrt(eta)*entropy(probs)). Specified in
            units of 1e-4. Set to 0 to disable.
        epsilon_cutoff: Float that controls the cutoff threshold for Epsilon
            sampling (simple probability threshold truncation). Specified in
            units of 1e-4. Set to 0 to disable.
        typical_p: Float that controls the cumulative probability of tokens
            closest in surprise to the expected surprise to consider. Must
            be in (0, 1]. Set to 1 to disable.
        mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
        mirostat_tau: Target "surprisal" that mirostat works towards. Range
            [0, inf).
        mirostat_eta: Rate at which mirostat updates its internal surprisal
            value. Range [0, inf).
        dynatemp_range: The range to use for dynamic temperature. When used,
            the actual temperature is allowed to be automatically adjusted
            dynamically between DynaTemp ± DynaTempRange. For example,
            setting `temperature=0.4` and `dynatemp_range=0.1` will result
            in a minimum temp of 0.3 and a maximum of 0.5.
        dynatemp_exponent: Exponent for dynatemp sampling. Range [0, inf).
        sampler_order: List of lists specifying the order in which samplers
            are applied. All samplers in a sublist are applied in parallel
            and the results are combined; the combinator is hardcoded to be
            "and" for now. Samplers are specified as strings; see
            "sampler.py" for the sampler codes.
        smoothing_factor: Smoothing factor for Quadratic Sampling.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their
            length. Used in beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops
            as soon as there are `best_of` complete candidates; `False`,
            where a heuristic is applied and the generation stops when it is
            very unlikely to find better candidates; `"never"`, where the
            beam search procedure only stops when there cannot be better
            candidates (canonical beam search algorithm).
        stop: List of strings that stop the generation when they are
            generated. The returned output will not contain the stop
            strings.
        stop_token_ids: List of tokens that stop the generation when they
            are generated. The returned output will contain the stop tokens
            unless the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            the output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
        logprobs: Number of log probabilities to return per output token.
            Note that the implementation follows the OpenAI API: the return
            result includes the log probabilities on the `logprobs` most
            likely tokens, as well as the chosen tokens. The API will always
            return the log probability of the sampled token, so there may be
            up to `logprobs+1` elements in the response.
        prompt_logprobs: Number of log probabilities to return per prompt
            token.
        custom_token_bans: List of token IDs to ban from being generated.
        skip_special_tokens: Whether to skip special tokens in the output.
            Defaults to True.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output. Defaults to True.
        logits_processors: List of LogitsProcessors to change the
            probability of token prediction at runtime.
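
    Example (illustrative; the values shown are arbitrary):
        params = SamplingParams(temperature=0.7, top_p=0.9,
                                max_tokens=128, stop=["###"])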
  125. """

    def __init__(
        self,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        top_a: float = 0.0,
        min_p: float = 0.0,
        tfs: float = 1.0,
        eta_cutoff: float = 0.0,
        epsilon_cutoff: float = 0.0,
        typical_p: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 0,
        mirostat_eta: float = 0,
        dynatemp_range: float = 0,
        dynatemp_exponent: float = 1,
        sampler_order: Optional[List[List[str]]] = None,
        smoothing_factor: float = 0.0,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Union[None, str, List[str]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        custom_token_bans: Optional[List[int]] = None,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessorFunc]] = None,
    ) -> None:
        self.n = n
        self.best_of = best_of if best_of is not None else n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.top_a = top_a
        self.min_p = min_p
        self.tfs = tfs
        self.eta_cutoff = eta_cutoff
        self.epsilon_cutoff = epsilon_cutoff
        self.typical_p = typical_p
        self.mirostat_mode = mirostat_mode
        self.mirostat_tau = mirostat_tau
        self.mirostat_eta = mirostat_eta
        self.dynatemp_range = dynatemp_range
        self.dynatemp_exponent = dynatemp_exponent
        self.sampler_order = sampler_order
        self.smoothing_factor = smoothing_factor
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        self.stop_token_ids = stop_token_ids or []
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        self.custom_token_bans = custom_token_bans or []
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.logits_processors = logits_processors or []
        self.include_stop_str_in_output = include_stop_str_in_output
        if not self.sampler_order:
            self.sampler_order = [
                "pens", "temp", "miro", "typ", "quad", "tfs", "minp", "eta",
                "topa", "topp", "eps", "topk"
            ]
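
        # Normalize sampler_order: wrap bare string/int entries into
        # single-element subgroups, then translate KoboldAI integer IDs into
        # sampler name strings via _sampler_map. For example (illustrative):
        # sampler_order=[5, ["topk", "topp"], 2] becomes
        # [["temp"], ["topk", "topp"], ["topp"]].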
        self.sampler_order = [
            [s] if isinstance(s, (str, int)) else s for s in self.sampler_order
        ]
        self.sampler_order = [[
            _sampler_map[s] if isinstance(s, int) else s for s in sub
        ] for sub in self.sampler_order]
        self.verify()

    def verify(self) -> None:
        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self.top_a = 0.0
                self._verify_greedy_sampling()

    def _verify_args(self) -> None:
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.repetition_penalty < 1.0:
            raise ValueError("repetition_penalty must be in [1, inf), got "
                             f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if self.top_a < 0:
            raise ValueError(f"top_a must be non-negative, got {self.top_a}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if not 0.0 < self.tfs <= 1.0:
            raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
        if self.epsilon_cutoff < 0.0 or self.epsilon_cutoff > 1000.0:
            raise ValueError("epsilon_cutoff must be in [0, 1000], got "
                             f"{self.epsilon_cutoff}.")
        # pylint: disable=unneeded-not
        if not self.eta_cutoff >= 0:
            raise ValueError(
                f"eta_cutoff must be non-negative, got {self.eta_cutoff}.")
        if not 0.0 <= self.typical_p <= 1.0:
            raise ValueError(
                f"typical_p must be in (0, 1], got {self.typical_p}.")
        if not self.dynatemp_range >= 0:
            raise ValueError("dynatemp_range must be non-negative, got "
                             f"{self.dynatemp_range}.")
        if not self.dynatemp_exponent >= 0:
            raise ValueError(f"dynatemp_exponent must be non-negative, got "
                             f"{self.dynatemp_exponent}.")
        if not self.smoothing_factor >= 0:
            raise ValueError(f"smoothing_factor must be non-negative, got "
                             f"{self.smoothing_factor}.")
        if self.mirostat_mode:
            if not self.mirostat_mode == 2:
                raise ValueError(
                    "Only Mirostat v2 (2) and disabled (0) supported, "
                    f"got {self.mirostat_mode}")
            if not self.mirostat_eta >= 0:
                raise ValueError(
                    f"mirostat_eta must be non-negative, got "
                    f"{self.mirostat_eta}")
            if not self.mirostat_tau >= 0:
                raise ValueError(
                    f"mirostat_tau must be non-negative, got "
                    f"{self.mirostat_tau}")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError("prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        for subgroup in self.sampler_order:
            if len(subgroup) > 1:
                if any(s in ["temp", "pens", "miro"] for s in subgroup):
                    raise ValueError("temp, pens and miro must be alone "
                                     f"in their own subgroup, got {subgroup}")

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        if self.best_of > 1:
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

    @cached_property
    def sampling_type(self) -> SamplingType:
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        return SamplingType.RANDOM

    def clone(self) -> "SamplingParams":
        """Deep copy, excluding the LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        """
        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (f"SamplingParams(n={self.n}, "
                f"best_of={self.best_of}, "
                f"presence_penalty={self.presence_penalty}, "
                f"frequency_penalty={self.frequency_penalty}, "
                f"repetition_penalty={self.repetition_penalty}, "
                f"temperature={self.temperature}, "
                f"top_p={self.top_p}, "
                f"top_k={self.top_k}, "
                f"top_a={self.top_a}, "
                f"min_p={self.min_p}, "
                f"tfs={self.tfs}, "
                f"eta_cutoff={self.eta_cutoff}, "
                f"epsilon_cutoff={self.epsilon_cutoff}, "
                f"typical_p={self.typical_p}, "
                f"mirostat_mode={self.mirostat_mode}, "
                f"mirostat_tau={self.mirostat_tau}, "
                f"mirostat_eta={self.mirostat_eta}, "
                f"dynatemp_range={self.dynatemp_range}, "
                f"dynatemp_exponent={self.dynatemp_exponent}, "
                f"sampler_order={self.sampler_order}, "
                f"smoothing_factor={self.smoothing_factor}, "
                f"use_beam_search={self.use_beam_search}, "
                f"length_penalty={self.length_penalty}, "
                f"early_stopping={self.early_stopping}, "
                f"stop={self.stop}, "
                f"stop_token_ids={self.stop_token_ids}, "
                "include_stop_str_in_output="
                f"{self.include_stop_str_in_output}, "
                f"ignore_eos={self.ignore_eos}, "
                f"max_tokens={self.max_tokens}, "
                f"custom_token_bans={self.custom_token_bans}, "
                f"logprobs={self.logprobs}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"skip_special_tokens={self.skip_special_tokens}, "
                "spaces_between_special_tokens="
                f"{self.spaces_between_special_tokens})")