"""Sampling parameters for text generation."""
import ast
import copy
import os
from enum import IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from loguru import logger
from pydantic import Field
from typing_extensions import Annotated

_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2

APHRODITE_NO_DEPRECATION_WARNING = bool(
    int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
    BEAM = 3


LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                            Callable[[List[int], List[int], torch.Tensor],
                                     torch.Tensor]]
"""LogitsProcessor is a function that takes a list of previously generated
tokens, the logits tensor for the next token and, optionally, prompt tokens
as a first argument, and returns a modified tensor of logits to sample
from."""
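

# A minimal sketch of a logits processor matching the first signature above.
# Illustrative only: this helper is not part of the module's public surface,
# and the banned token id is arbitrary.
def _example_ban_token_zero(token_ids: List[int],
                            logits: torch.Tensor) -> torch.Tensor:
    # `token_ids` holds the previously generated tokens; this example
    # ignores them and simply forbids token id 0 from being sampled.
    logits[0] = -float("inf")
    return logits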


class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text
    completion API
    (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support multiple additional samplers that are not
    supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the
            prompt. From these `best_of` sequences, the top `n` sequences
            are returned. `best_of` must be greater than or equal to `n`.
            This is treated as the beam width when `use_beam_search` is
            True. By default, `best_of` is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether
            they appear in the generated text so far. Values > 0 encourage
            the model to use new tokens, while values < 0 encourage the
            model to repeat tokens.
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage
            the model to use new tokens, while values < 0 encourage the
            model to repeat tokens.
        repetition_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. frequency_penalty is
            applied additively, while repetition_penalty is applied
            multiplicatively. Must be in [1, inf). Set to 1 to disable.
        dry_multiplier: Float that controls the magnitude of the penalty
            for the shortest penalized sequences in the DRY sampler. Set to
            values greater than 0 to enable DRY sampling.
        dry_base: Float that controls how fast the penalty grows with
            increasing sequence length in the DRY sampler.
        dry_allowed_length: Integer that controls the maximum length of
            sequences that can be repeated without being penalized in the
            DRY sampler.
        dry_sequence_breakers: Tokens across which sequence matching is not
            continued. Specified as a comma-separated list of quoted
            strings.
        temperature: Float that controls the randomness of the sampling.
            Lower values make the model more deterministic, while higher
            values make the model more random. Zero means greedy sampling.
        temperature_last: Whether to apply temperature after all other
            samplers instead of before them.
        top_p: Float that controls the cumulative probability of the top
            tokens to consider. Must be in (0, 1]. Set to 1 to consider all
            tokens.
        top_k: Integer that controls the number of top tokens to consider.
            Set to -1 to consider all tokens.
        top_a: Float that controls the cutoff for Top-A sampling. The exact
            cutoff is top_a*max_prob**2. Must be in [0, inf). Set to 0 to
            disable.
        min_p: Float that controls the cutoff for min-p sampling. The exact
            cutoff is min_p*max_prob. Must be in [0, 1]. Set to 0 to
            disable.
        tfs: Float that controls the cumulative approximate curvature of
            the distribution to retain for Tail Free Sampling. Must be in
            (0, 1]. Set to 1 to disable.
        eta_cutoff: Float that controls the cutoff threshold for Eta
            sampling (a form of entropy-adaptive truncation sampling). The
            threshold is computed as min(eta, sqrt(eta)*entropy(probs)).
            Specified in units of 1e-4. Set to 0 to disable.
        epsilon_cutoff: Float that controls the cutoff threshold for
            Epsilon sampling (simple probability threshold truncation).
            Specified in units of 1e-4. Set to 0 to disable.
        typical_p: Float that controls the cumulative probability of tokens
            closest in surprise to the expected surprise to consider. Must
            be in (0, 1]. Set to 1 to disable.
        mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
        mirostat_tau: Target "surprisal" that mirostat works towards. Range
            [0, inf).
        mirostat_eta: Rate at which mirostat updates its internal surprisal
            value. Range [0, inf).
        dynatemp_min: Minimum temperature for dynatemp sampling. Range
            [0, inf).
        dynatemp_max: Maximum temperature for dynatemp sampling. Range
            [0, inf).
        dynatemp_exponent: Exponent for dynatemp sampling. Range [0, inf).
        smoothing_factor: Smoothing factor for Quadratic Sampling.
        smoothing_curve: Smoothing curve for Quadratic (Cubic) Sampling.
        seed: Random seed to use for the generation.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their
            length. Used in beam search.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation
            stops as soon as there are `best_of` complete candidates;
            `False`, where a heuristic is applied and the generation stops
            when it is very unlikely to find better candidates; `"never"`,
            where the beam search procedure only stops when there cannot be
            better candidates (canonical beam search algorithm).
        stop: List of strings that stop the generation when they are
            generated. The returned output will not contain the stop
            strings.
        stop_token_ids: List of tokens that stop the generation when they
            are generated. The returned output will contain the stop tokens
            unless the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            the output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output
            sequence.
        min_tokens: Minimum number of tokens to generate per output
            sequence before EOS or stop tokens are generated.
        logprobs: Number of log probabilities to return per output token.
            When set to None, no probability is returned. If set to a
            non-None value, the result includes the log probabilities of
            the specified number of most likely tokens, as well as the
            chosen tokens. Note that the implementation follows the OpenAI
            API: the API will always return the log probability of the
            sampled token, so there may be up to `logprobs+1` elements in
            the response.
        prompt_logprobs: Number of log probabilities to return per prompt
            token.
        detokenize: Whether to detokenize the output. Defaults to True.
        custom_token_bans: List of token IDs to ban from being generated.
        skip_special_tokens: Whether to skip special tokens in the output.
            Defaults to True.
        spaces_between_special_tokens: Whether to add spaces between
            special tokens in the output. Defaults to True.
        logits_processors: List of functions that modify logits based on
            previously generated tokens, and optionally prompt tokens as a
            first argument.
        truncate_prompt_tokens: If set to an integer k, will use only the
            last k tokens from the prompt (i.e. left-truncation). Defaults
            to None (i.e. no truncation).
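
    Example:
        A typical construction (the values shown are illustrative only):

        >>> params = SamplingParams(temperature=0.8, top_p=0.95,
        ...                         max_tokens=128)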
    """

    def __init__(
        self,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        dry_multiplier: float = 0.0,
        dry_base: float = 1.75,
        dry_allowed_length: int = 2,
        dry_sequence_breakers: Union[str, List[List[int]]] = '"\\n", ":", "\\"", "*"',
        temperature: float = 1.0,
        temperature_last: bool = False,
        top_p: float = 1.0,
        top_k: int = -1,
        top_a: float = 0.0,
        min_p: float = 0.0,
        tfs: float = 1.0,
        eta_cutoff: float = 0.0,
        epsilon_cutoff: float = 0.0,
        typical_p: float = 1.0,
        smoothing_factor: float = 0.0,
        smoothing_curve: float = 1.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Union[None, str, List[str]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        custom_token_bans: Optional[List[int]] = None,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessorFunc]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    ) -> None:
        self.n = n
        self.best_of = best_of if best_of is not None else n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.repetition_penalty = repetition_penalty
        self.dry_multiplier = dry_multiplier
        self.dry_base = dry_base
        self.dry_allowed_length = dry_allowed_length
        self.dry_sequence_breakers = self._parse_dry_sequence_breakers(
            dry_sequence_breakers)
        if 0 < temperature < _MAX_TEMP:
            logger.warning(
                f"temperature {temperature} is less than {_MAX_TEMP}, "
                f"which may cause numerical errors (NaN or Inf) in tensors. "
                f"The temperature has been raised to {_MAX_TEMP}.")
            # Clamp near-zero temperatures up to _MAX_TEMP (min() would be
            # a no-op here, since temperature < _MAX_TEMP in this branch).
            temperature = max(temperature, _MAX_TEMP)
        self.temperature = temperature
        self.temperature_last = temperature_last
        self.top_p = top_p
        self.top_k = top_k
        self.top_a = top_a
        self.min_p = min_p
        self.tfs = tfs
        self.eta_cutoff = eta_cutoff
        self.epsilon_cutoff = epsilon_cutoff
        self.typical_p = typical_p
        self.smoothing_factor = smoothing_factor
        self.smoothing_curve = smoothing_curve
        if seed == -1:
            self.seed = None
        else:
            self.seed = seed
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        self.stop_token_ids = stop_token_ids or []
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.min_tokens = min_tokens
        self.logprobs = 1 if logprobs is True else logprobs
        self.prompt_logprobs = (1 if prompt_logprobs is True else
                                prompt_logprobs)
        # NOTE: This parameter is only exposed at the engine level for now.
        # It is not exposed in the OpenAI API server, as the OpenAI API does
        # not support returning only a list of token IDs.
        self.detokenize = detokenize
        self.custom_token_bans = custom_token_bans or []
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.logits_processors = logits_processors or []
        self.include_stop_str_in_output = include_stop_str_in_output
        self.truncate_prompt_tokens = truncate_prompt_tokens
        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not include_stop_str_in_output:
            self.output_text_buffer_length = max(
                len(s) for s in self.stop) - 1
        else:
            self.output_text_buffer_length = 0
        self.default_values = {
            "n": 1,
            "best_of": 1,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "repetition_penalty": 1.0,
            "dry_multiplier": 0.0,
            "dry_base": 1.75,
            "dry_allowed_length": 2,
            # Use the parsed form of the default breaker string, so that
            # __repr__ does not report dry_sequence_breakers as non-default
            # on every instance.
            "dry_sequence_breakers": ["\n", ":", "\"", "*"],
            "temperature": 1.0,
            "temperature_last": False,
            "top_p": 1.0,
            "top_k": -1,
            "top_a": 0.0,
            "min_p": 0.0,
            "tfs": 1.0,
            "eta_cutoff": 0.0,
            "epsilon_cutoff": 0.0,
            "typical_p": 1.0,
            "smoothing_factor": 0.0,
            "smoothing_curve": 1.0,
            "seed": None,
            "use_beam_search": False,
            "length_penalty": 1.0,
            "early_stopping": False,
            "stop": [],
            "stop_token_ids": [],
            "ignore_eos": False,
            "max_tokens": 16,
            "min_tokens": 0,
            "logprobs": None,
            "prompt_logprobs": None,
            "detokenize": True,
            "custom_token_bans": [],
            "skip_special_tokens": True,
            "spaces_between_special_tokens": True,
            "include_stop_str_in_output": False,
            "truncate_prompt_tokens": None,
        }
        self._verify_args()
        if self.use_beam_search:
            if not APHRODITE_NO_DEPRECATION_WARNING:
                logger.warning(
                    "[IMPORTANT] We plan to discontinue the support for "
                    "beam search in the next major release. Set "
                    "APHRODITE_NO_DEPRECATION_WARNING=1 to "
                    "suppress this warning.")
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self.top_a = 0.0
                self._verify_greedy_sampling()
        # eos_token_id is added to this by the engine
        self.all_stop_token_ids = set(self.stop_token_ids)

    def _parse_dry_sequence_breakers(
        self, dry_sequence_breakers: Union[str, List[List[int]]]
    ) -> Union[List[str], List[List[int]]]:
        """Parse the dry_sequence_breakers input into a list.

        A string such as the default '"\\n", ":", "\\"", "*"' is parsed
        with ast.literal_eval into ['\\n', ':', '"', '*']; lists are passed
        through unchanged.
        """
        if isinstance(dry_sequence_breakers, list):
            return dry_sequence_breakers
        try:
            # Safely evaluate the comma-separated quoted strings as a
            # Python list literal.
            parsed = ast.literal_eval(f'[{dry_sequence_breakers}]')
            return [str(item) for item in parsed]
        except (SyntaxError, ValueError):
            # If parsing fails, treat the whole string as a single breaker.
            return [dry_sequence_breakers]

    def tokenize_dry_sequence_breakers(self, tokenizer):
        """Convert string sequence breakers to token id lists in place."""
        if (not self.dry_sequence_breakers
                or not isinstance(self.dry_sequence_breakers[0], str)):
            # Nothing to do, or already tokenized.
            return
        tokenized_breakers = []
        for breaker in self.dry_sequence_breakers:
            tokenized_breaker = tokenizer.encode(breaker,
                                                 add_special_tokens=False)
            tokenized_breakers.append(tokenized_breaker)
        self.dry_sequence_breakers = tokenized_breakers

    def _verify_args(self) -> None:
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.repetition_penalty < 1.0:
            raise ValueError("repetition_penalty must be in [1, inf), got "
                             f"{self.repetition_penalty}.")
        if self.dry_multiplier < 0.0:
            raise ValueError("dry_multiplier must be non-negative, got "
                             f"{self.dry_multiplier}.")
        if self.dry_base < 1.0:
            raise ValueError(
                f"dry_base must be at least 1, got {self.dry_base}.")
        if self.dry_allowed_length < 1:
            raise ValueError("dry_allowed_length must be at least 1, got "
                             f"{self.dry_allowed_length}.")
        if not all(isinstance(s, str) for s in self.dry_sequence_breakers):
            raise ValueError(
                "dry_sequence_breakers must be a list of strings.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if self.top_a < 0:
            raise ValueError(
                f"top_a must be non-negative, got {self.top_a}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if not 0.0 < self.tfs <= 1.0:
            raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
        if self.epsilon_cutoff < 0.0 or self.epsilon_cutoff > 1000.0:
            raise ValueError("epsilon_cutoff must be in [0, 1000], got "
                             f"{self.epsilon_cutoff}.")
        if self.eta_cutoff < 0:
            raise ValueError(
                f"eta_cutoff must be non-negative, got {self.eta_cutoff}.")
        if not 0.0 < self.typical_p <= 1.0:
            raise ValueError(
                f"typical_p must be in (0, 1], got {self.typical_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to "
                             f"0, got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError("prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and self.truncate_prompt_tokens < 1):
            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                             f"got {self.truncate_prompt_tokens}.")
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using "
                             f"beam search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        if self.best_of > 1:
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

    def update_from_generation_config(
            self,
            generation_config: Dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Update if there are non-default values from generation_config."""
        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self.all_stop_token_ids.add(model_eos_token_id)
        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self.all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    @cached_property
    def sampling_type(self) -> SamplingType:
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    def clone(self) -> "SamplingParams":
        """Deep copy excluding LogitsProcessor objects.

        LogitsProcessor objects are excluded because they may contain an
        arbitrary, nontrivial amount of data.
        """
        # Seeding the deepcopy memo with {id(lp): lp} makes deepcopy treat
        # each logits processor as already copied, so the clone shares the
        # original processor objects instead of duplicating them.
        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        repr_str = "SamplingParams("
        for param, default_value in self.default_values.items():
            current_value = getattr(self, param)
            if current_value != default_value:
                repr_str += f"{param}={current_value}, "
        repr_str = repr_str.rstrip(", ") + ")"
        return repr_str
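

if __name__ == "__main__":
    # Illustrative smoke test only; the parameter values are arbitrary and
    # this block is not part of the library's public surface.
    sp = SamplingParams(temperature=0.7, top_k=40, max_tokens=64)
    print(sp)                # repr lists only the non-default fields
    print(sp.sampling_type)  # SamplingType.RANDOM: temperature > 0, no seed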