123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- """Sampling parameters for text generation."""
- from enum import IntEnum
- from functools import cached_property
- from typing import Callable, List, Optional, Union
- import torch
- _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
    """Strategy used to pick tokens for a request.

    Derived by ``SamplingParams.sampling_type``; the integer values are
    explicit and must stay stable.
    """

    GREEDY = 0  # temperature below _SAMPLING_EPS (no beam search)
    RANDOM = 1  # stochastic sampling, no seed supplied
    RANDOM_SEED = 2  # stochastic sampling with an explicit seed
    BEAM = 3  # use_beam_search=True
# A logits processor is called with the logits tensor and the corresponding
# lists of previously generated output token ids; it returns nothing and is
# expected to modify the logits tensor itself.
LogitsProcessorFunc = Callable[
    [torch.Tensor, List[List[int]]],
    None,
]
class SamplingParams:
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support multiple additional samplers which are not
    supported by OpenAI.

    Args:
        n: Number of output sequences to return for the given prompt.
            Must be at least 1.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
            `best_of` must be greater than or equal to `n`. This is treated as
            the beam width when `use_beam_search` is True. By default, `best_of`
            is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
            tokens. Must be in [-2, 2].
        frequency_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far. Values > 0 encourage the
            model to use new tokens, while values < 0 encourage the model to
            repeat tokens. Must be in [-2, 2].
        repetition_penalty: Float that penalizes new tokens based on their
            frequency in the generated text so far.
            freq_pen is applied additively while
            rep_pen is applied multiplicatively.
            Must be in [1, inf). Set to 1 to disable the effect.
        temperature: Float that controls the randomness of the sampling. Lower
            values make the model more deterministic, while higher values make
            the model more random. Zero means greedy sampling. Must be
            non-negative.
        top_p: Float that controls the cumulative probability of the top tokens
            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
        top_k: Integer that controls the number of top tokens to consider. Set
            to -1 to consider all tokens; otherwise it must be at least 1.
        top_a: Float that controls the cutoff for Top-A sampling.
            Exact cutoff is top_a*max_prob**2. Must be in [0, inf), 0 to
            disable.
        min_p: Float that controls the cutoff for min-p sampling.
            Exact cutoff is min_p*max_prob. Must be in [0, 1], 0 to disable.
        tfs: Float that controls the cumulative approximate curvature of the
            distribution to retain for Tail Free Sampling.
            Must be in (0, 1]. Set to 1 to disable.
        eta_cutoff: Float that controls the cutoff threshold for Eta sampling
            (a form of entropy adaptive truncation sampling)
            threshold is computed as min(eta, sqrt(eta)*entropy(probs)).
            Specified in units of 1e-4. Must be non-negative. Set to 0 to
            disable.
        epsilon_cutoff: Float that controls the cutoff threshold for
            Epsilon sampling (simple probability threshold truncation).
            Specified in units of 1e-4. Must be in [0, 1000]. Set to 0 to
            disable.
        typical_p: Float that controls the cumulative probability of tokens
            closest in surprise to the expected surprise to consider.
            Must be in [0, 1]. Set to 1 to disable.
        typical_p_sigma: Used to scale the maximum threshold for positive
            deviations in typical_p. Range in [0, inf). Set to 0 to disable.
        mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
        mirostat_tau: Target "surprisal" that mirostat works towards.
            Range [0, inf). Only validated when mirostat is enabled.
        mirostat_eta: Rate at which mirostat updates its internal surprisal
            value. Range [0, inf). Only validated when mirostat is enabled.
        dynatemp_min: Minimum temperature for dynatemp sampling.
            Range [0, inf).
        dynatemp_max: Maximum temperature for dynatemp sampling.
            Range [0, inf).
        dynatemp_exponent: Exponent for dynatemp sampling. Range [0, inf).
        smoothing_factor: Smoothing factor for Quadratic Sampling.
            Range [0, inf).
        smoothing_curve: Smoothing curve for Quadratic (Cubic) Sampling.
            Must be at least 1.
        seed: Random seed to use for the generation.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search; must stay at 1.0 otherwise.
        early_stopping: Controls the stopping condition for beam search. It
            accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
            unlikely to find better candidates; `"never"`, where the beam search
            procedure only stops when there cannot be better candidates
            (canonical beam search algorithm). Must be False when not using
            beam search.
        stop: List of strings that stop the generation when they are generated.
            The returned output will not contain the stop strings. A single
            string is also accepted.
        stop_token_ids: List of tokens that stop the generation when they are
            generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
        include_stop_str_in_output: Whether to include the stop strings in
            output text. Defaults to False.
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
            If not None, must be at least 1.
        logprobs: Number of log probabilities to return per output token.
            Note that the implementation follows the OpenAI API: The return
            result includes the log probabilities on the `logprobs` most likely
            tokens, as well as the chosen tokens. The API will always return
            the log probability of the sampled token, so there may be up to
            `logprobs+1` elements in the response. Must be non-negative.
        prompt_logprobs: Number of log probabilities to return per prompt
            token. Must be non-negative.
        custom_token_bans: List of token IDs to ban from generating.
        skip_special_tokens: Whether to skip special tokens in the output.
            Defaults to True.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens in the output. Defaults to True.
        logits_processors: List of LogitsProcessors to change the probability
            of token prediction at runtime.

    Raises:
        ValueError: If any argument is outside its valid range or the
            combination of arguments is inconsistent (see the ``_verify_*``
            methods).
    """

    def __init__(
        self,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        top_a: float = 0.0,
        min_p: float = 0.0,
        tfs: float = 1.0,
        eta_cutoff: float = 0.0,
        epsilon_cutoff: float = 0.0,
        typical_p: float = 1.0,
        typical_p_sigma: float = 0.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 0,
        mirostat_eta: float = 0,
        dynatemp_min: float = 0,
        dynatemp_max: float = 0,
        dynatemp_exponent: float = 1,
        smoothing_factor: float = 0.0,
        smoothing_curve: float = 1.0,
        seed: Optional[int] = None,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Union[None, str, List[str]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        custom_token_bans: Optional[List[int]] = None,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[List[LogitsProcessorFunc]] = None,
    ) -> None:
        self.n = n
        # `best_of` falls back to `n` when not given explicitly.
        self.best_of = best_of if best_of is not None else n
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.repetition_penalty = repetition_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.top_a = top_a
        self.min_p = min_p
        self.tfs = tfs
        self.eta_cutoff = eta_cutoff
        self.epsilon_cutoff = epsilon_cutoff
        self.typical_p = typical_p
        self.typical_p_sigma = typical_p_sigma
        self.mirostat_mode = mirostat_mode
        self.mirostat_tau = mirostat_tau
        self.mirostat_eta = mirostat_eta
        self.dynatemp_min = dynatemp_min
        self.dynatemp_max = dynatemp_max
        self.dynatemp_exponent = dynatemp_exponent
        self.smoothing_factor = smoothing_factor
        self.smoothing_curve = smoothing_curve
        self.seed = seed
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        # Normalize `stop` to a fresh list: None -> [], "s" -> ["s"].
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        # As with `stop`, keep private copies of caller-supplied lists so that
        # mutating the argument later cannot silently change these params.
        self.stop_token_ids = list(stop_token_ids) if stop_token_ids else []
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        self.custom_token_bans = (list(custom_token_bans)
                                  if custom_token_bans else [])
        self.skip_special_tokens = skip_special_tokens
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.logits_processors = (list(logits_processors)
                                  if logits_processors else [])
        self.include_stop_str_in_output = include_stop_str_in_output
        # Reference values consulted by __repr__, which prints only the
        # parameters whose current value differs from the entry here.
        # `best_of` defaults to `n`, so it is printed whenever n > 1.
        self.default_values = {
            "n": 1,
            "best_of": 1,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "repetition_penalty": 1.0,
            "temperature": 1.0,
            "top_p": 1.0,
            "top_k": -1,
            "top_a": 0.0,
            "min_p": 0.0,
            "tfs": 1.0,
            "eta_cutoff": 0.0,
            "epsilon_cutoff": 0.0,
            "typical_p": 1.0,
            "typical_p_sigma": 0.0,
            "mirostat_mode": 0,
            "mirostat_tau": 0,
            "mirostat_eta": 0,
            "dynatemp_min": 0,
            "dynatemp_max": 0,
            "dynatemp_exponent": 1,
            "smoothing_factor": 0.0,
            "smoothing_curve": 1.0,
            "seed": None,
            "use_beam_search": False,
            "length_penalty": 1.0,
            "early_stopping": False,
            "stop": [],
            "stop_token_ids": [],
            "ignore_eos": False,
            "max_tokens": 16,
            "logprobs": None,
            "prompt_logprobs": None,
            "custom_token_bans": [],
            "skip_special_tokens": True,
            "spaces_between_special_tokens": True,
            "include_stop_str_in_output": False
        }

        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        else:
            self._verify_non_beam_search()
            if self.temperature < _SAMPLING_EPS:
                # Zero temperature means greedy sampling. Reset these
                # truncation samplers to their disabled (default) values.
                self.top_p = 1.0
                self.top_k = -1
                self.min_p = 0.0
                self.top_a = 0.0
                self._verify_greedy_sampling()

    def _verify_args(self) -> None:
        """Validate the range of each individual parameter.

        Raises:
            ValueError: On the first out-of-range parameter.
        """
        # Float checks are written as `not <valid condition>` on purpose:
        # every comparison with NaN is False, so this form also rejects NaN,
        # whereas `x < bound` would let NaN through.
        # pylint: disable=unneeded-not
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            raise ValueError(f"best_of must be greater than or equal to n, "
                             f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if not self.repetition_penalty >= 1.0:
            raise ValueError("repetition_penalty must be in [1, inf), got "
                             f"{self.repetition_penalty}.")
        if not self.temperature >= 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not self.top_a >= 0:
            raise ValueError(f"top_a must be non negative, got {self.top_a}.")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if not 0.0 < self.tfs <= 1.0:
            raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
        if not 0.0 <= self.epsilon_cutoff <= 1000.0:
            raise ValueError("epsilon_cutoff must be in [0, 1000], got "
                             f"{self.epsilon_cutoff}.")
        if not self.eta_cutoff >= 0:
            raise ValueError(
                f"eta_cutoff must be non negative, got {self.eta_cutoff}.")
        # 0 is accepted here, so the message states the closed range.
        if not 0.0 <= self.typical_p <= 1.0:
            raise ValueError(
                f"typical_p must be in [0, 1], got {self.typical_p}.")
        if not self.typical_p_sigma >= 0:
            raise ValueError(f"typical_p_sigma must be non negative, got "
                             f"{self.typical_p_sigma}.")
        if not self.dynatemp_min >= 0:
            raise ValueError(
                f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
        if not self.dynatemp_max >= 0:
            raise ValueError(
                f"dynatemp_max must be non negative, got {self.dynatemp_max}.")
        if not self.dynatemp_exponent >= 0:
            raise ValueError(f"dynatemp_exponent must be non negative, got "
                             f"{self.dynatemp_exponent}.")
        if not self.smoothing_factor >= 0:
            raise ValueError(f"smoothing_factor must be non negative, got "
                             f"{self.smoothing_factor}.")
        if not self.smoothing_curve >= 1.0:
            raise ValueError(f"smoothing_curve must be at least 1, got "
                             f"{self.smoothing_curve}.")
        # Mirostat parameters only matter (and are only checked) when enabled.
        if self.mirostat_mode:
            if self.mirostat_mode != 2:
                raise ValueError(
                    "Only Mirostat v2 (2) and disabled (0) supported, "
                    f"got {self.mirostat_mode}")
            if not self.mirostat_eta >= 0:
                raise ValueError(
                    f"mirostat_eta must be non-negative, got "
                    f"{self.mirostat_eta}")
            if not self.mirostat_tau >= 0:
                raise ValueError(
                    f"mirostat_tau must be non-negative, got "
                    f"{self.mirostat_tau}")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
            raise ValueError("prompt_logprobs must be non-negative, got "
                             f"{self.prompt_logprobs}.")

    def _verify_beam_search(self) -> None:
        """Check the parameter combination required by beam search.

        Raises:
            ValueError: If best_of is 1, temperature is non-zero, top_p/top_k
                are not disabled, or early_stopping has an unsupported value.
        """
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            raise ValueError(
                f"early_stopping must be True, False, or 'never', "
                f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        """Reject beam-search-only options when beam search is off.

        Raises:
            ValueError: If early_stopping is not False or length_penalty is
                not (within _SAMPLING_EPS of) 1.0.
        """
        if self.early_stopping is not False:
            raise ValueError("early_stopping is not effective and must be "
                             "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            raise ValueError(
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        """Check the parameter combination required by greedy sampling.

        When called from __init__, top_p and top_k have just been reset, so
        in practice only the best_of check can fail there.

        Raises:
            ValueError: If best_of > 1, or top_p/top_k are not disabled.
        """
        if self.best_of > 1:
            raise ValueError("best_of must be 1 when using greedy sampling. "
                             f"Got {self.best_of}.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

    @cached_property
    def sampling_type(self) -> SamplingType:
        """Classify these params as beam, greedy, seeded or plain random.

        Cached on first access: later mutation of use_beam_search,
        temperature or seed is not reflected.
        """
        if self.use_beam_search:
            return SamplingType.BEAM
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    def __repr__(self) -> str:
        """Show only the parameters that differ from ``default_values``."""
        changed = []
        for param, default_value in self.default_values.items():
            current_value = getattr(self, param)
            if current_value != default_value:
                changed.append(f"{param}={current_value}")
        return "SamplingParams(" + ", ".join(changed) + ")"
|