@@ -3,11 +3,11 @@ import copy
 import os
 from enum import IntEnum
 from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union
 
+import msgspec
 import torch
 from loguru import logger
-from pydantic import Field
 from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
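A note on the `pydantic` to `msgspec` swap in the imports above: `msgspec.Meta` takes over the role of pydantic's `Field` inside `Annotated`, but the constraint is enforced when a message is decoded into the struct, not on attribute assignment. A minimal sketch (the `Params` struct here is hypothetical, not part of this file):

```python
import msgspec
from typing_extensions import Annotated

class Params(msgspec.Struct):
    # Same shape as truncate_prompt_tokens below: an int constrained to >= 1.
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=1)] = 1

# The ge=1 bound is checked during typed decoding:
msgspec.json.decode(b'{"truncate_prompt_tokens": 4}', type=Params)   # ok
try:
    msgspec.json.decode(b'{"truncate_prompt_tokens": 0}', type=Params)
except msgspec.ValidationError as exc:
    print(exc)  # Expected `int` >= 1 - at `$.truncate_prompt_tokens`
```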
@@ -34,7 +34,10 @@ first argument, and returns a modified tensor of logits
 to sample from."""
 
 
-class SamplingParams:
+class SamplingParams(
+        msgspec.Struct,
+        omit_defaults=True,
+        dict=True):
     """Sampling parameters for text generation.
 
     Overall, we follow the sampling parameters from the OpenAI text completion
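The two `msgspec.Struct` options used here do the heavy lifting: `omit_defaults=True` keeps serialized messages compact by skipping any field still at its default, and `dict=True` gives instances a `__dict__`, which is what lets `cached_property` (imported above) work on an otherwise slots-based struct. A toy illustration, assuming msgspec's documented behavior:

```python
import msgspec
from functools import cached_property

class Point(msgspec.Struct, omit_defaults=True, dict=True):
    x: int = 0
    y: int = 0

    @cached_property
    def norm_sq(self) -> int:
        # cached_property caches into __dict__, hence the need for dict=True.
        return self.x * self.x + self.y * self.y

p = Point(y=3)
print(msgspec.json.encode(p))  # b'{"y":3}' -- x is omitted, still at default
print(p.norm_sq)               # 9, computed once then cached
```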
@@ -147,161 +150,127 @@ class SamplingParams:
             0 disables the sampler, 1 makes it always happen.
     """
 
-    def __init__(
-        self,
-        n: int = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repetition_penalty: float = 1.0,
-        temperature: float = 1.0,
-        dynatemp_min: float = 0.0,
-        dynatemp_max: float = 0.0,
-        dynatemp_exponent: float = 1.0,
-        temperature_last: bool = False,
-        top_p: float = 1.0,
-        top_k: int = -1,
-        top_a: float = 0.0,
-        min_p: float = 0.0,
-        tfs: float = 1.0,
-        eta_cutoff: float = 0.0,
-        epsilon_cutoff: float = 0.0,
-        typical_p: float = 1.0,
-        smoothing_factor: float = 0.0,
-        smoothing_curve: float = 1.0,
-        seed: Optional[int] = None,
-        use_beam_search: bool = False,
-        length_penalty: float = 1.0,
-        early_stopping: Union[bool, str] = False,
-        stop: Union[None, str, List[str]] = None,
-        stop_token_ids: Optional[List[int]] = None,
-        include_stop_str_in_output: bool = False,
-        ignore_eos: bool = False,
-        max_tokens: Optional[int] = 16,
-        min_tokens: int = 0,
-        logprobs: Optional[int] = None,
-        prompt_logprobs: Optional[int] = None,
-        detokenize: bool = True,
-        custom_token_bans: Optional[List[int]] = None,
-        skip_special_tokens: bool = True,
-        spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[List[LogitsProcessorFunc]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
-        xtc_threshold: float = 0.1,
-        xtc_probability: float = 0,
-    ) -> None:
-        self.n = n
-        self.best_of = best_of if best_of is not None else n
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.repetition_penalty = repetition_penalty
-        if 0 < temperature < _MAX_TEMP:
+    n: int = 1
+    best_of: Optional[int] = None
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    repetition_penalty: float = 1.0
+    temperature: float = 1.0
+    dynatemp_min: float = 0.0
+    dynatemp_max: float = 0.0
+    dynatemp_exponent: float = 1.0
+    temperature_last: bool = False
+    top_p: float = 1.0
+    top_k: int = -1
+    top_a: float = 0.0
+    min_p: float = 0.0
+    tfs: float = 1.0
+    eta_cutoff: float = 0.0
+    epsilon_cutoff: float = 0.0
+    typical_p: float = 1.0
+    smoothing_factor: float = 0.0
+    smoothing_curve: float = 1.0
+    seed: Optional[int] = None
+    use_beam_search: bool = False
+    length_penalty: float = 1.0
+    early_stopping: Union[bool, str] = False
+    stop: Union[None, str, List[str]] = None
+    stop_token_ids: Optional[List[int]] = None
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    max_tokens: Optional[int] = 16
+    min_tokens: int = 0
+    logprobs: Optional[int] = None
+    prompt_logprobs: Optional[int] = None
+    detokenize: bool = True
+    custom_token_bans: Optional[List[int]] = None
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    # The actual type is Optional[List[LogitsProcessorFunc]].
+    # We use Any here because that callable type
+    # is not supported by msgspec.
+    logits_processors: Optional[Any] = None
+    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+    xtc_threshold: float = 0.1
+    xtc_probability: float = 0
+
+    # The fields below are not meant to be supplied as input;
+    # they are populated in __post_init__.
+    output_text_buffer_length: int = 0
+    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
+
+    default_values = {
+        "n": 1,
+        "best_of": 1,
+        "presence_penalty": 0.0,
+        "frequency_penalty": 0.0,
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "dynatemp_min": 0.0,
+        "dynatemp_max": 0.0,
+        "dynatemp_exponent": 1.0,
+        "temperature_last": False,
+        "top_p": 1.0,
+        "top_k": -1,
+        "top_a": 0.0,
+        "min_p": 0.0,
+        "tfs": 1.0,
+        "eta_cutoff": 0.0,
+        "epsilon_cutoff": 0.0,
+        "typical_p": 1.0,
+        "smoothing_factor": 0.0,
+        "smoothing_curve": 1.0,
+        "seed": None,
+        "use_beam_search": False,
+        "length_penalty": 1.0,
+        "early_stopping": False,
+        "stop": [],
+        "stop_token_ids": [],
+        "ignore_eos": False,
+        "max_tokens": 16,
+        "min_tokens": 0,
+        "logprobs": None,
+        "prompt_logprobs": None,
+        "detokenize": True,
+        "custom_token_bans": [],
+        "skip_special_tokens": True,
+        "spaces_between_special_tokens": True,
+        "include_stop_str_in_output": False,
+        "truncate_prompt_tokens": None,
+        "xtc_threshold": 0.1,
+        "xtc_probability": 0,
+    }
+
+    def __post_init__(self) -> None:
+        self.best_of = self.best_of or self.n
+        if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
-                f"temperature {temperature} is less than {_MAX_TEMP}, "
-                f"which may cause numerical errors (NaN or Inf) in tensors. "
-                f"We have capped the temperature to {_MAX_TEMP}.")
-            temperature = min(temperature, _MAX_TEMP)
-        self.temperature = temperature
-        self.dynatemp_min = dynatemp_min
-        self.dynatemp_max = dynatemp_max
-        self.dynatemp_exponent = dynatemp_exponent
-        self.temperature_last = temperature_last
-        self.top_p = top_p
-        self.top_k = top_k
-        self.top_a = top_a
-        self.min_p = min_p
-        self.tfs = tfs
-        self.eta_cutoff = eta_cutoff
-        self.epsilon_cutoff = epsilon_cutoff
-        self.typical_p = typical_p
-        self.smoothing_factor = smoothing_factor
-        self.smoothing_curve = smoothing_curve
-        if seed == -1:
+ f"temperature {self.temperature} is less than {_MAX_TEMP}, "
|
|
|
+ "which may cause numerical errors NaN or inf in tensors. We "
|
|
|
+ f"have maxed it out to {_MAX_TEMP}.")
+        self.temperature = max(self.temperature, _MAX_TEMP)
+        if self.seed == -1:
             self.seed = None
         else:
-            self.seed = seed
-        self.use_beam_search = use_beam_search
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-        if stop is None:
+            self.seed = self.seed
+        if self.stop is None:
             self.stop = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
+        elif isinstance(self.stop, str):
+            self.stop = [self.stop]
         else:
-            self.stop = list(stop)
-        self.stop_token_ids = stop_token_ids or []
-        self.ignore_eos = ignore_eos
-        self.max_tokens = max_tokens
-        self.min_tokens = min_tokens
-        self.logprobs = 1 if logprobs is True else logprobs
-        self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
-        # NOTE: This parameter is only exposed at the engine level for now.
-        # It is not exposed in the OpenAI API server, as the OpenAI API does
-        # not support returning only a list of token IDs.
-        self.detokenize = detokenize
-        self.custom_token_bans = custom_token_bans or []
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.logits_processors = logits_processors or []
-        self.include_stop_str_in_output = include_stop_str_in_output
-        self.truncate_prompt_tokens = truncate_prompt_tokens
-        # Number of characters to hold back for stop string evaluation
-        # until sequence is finished.
-        if self.stop and not include_stop_str_in_output:
-            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+            self.stop = list(self.stop)
+        if self.stop_token_ids is None:
+            self.stop_token_ids = []
         else:
-            self.output_text_buffer_length = 0
-        self.xtc_threshold = xtc_threshold
-        self.xtc_probability = xtc_probability
-
-        self.default_values = {
-            "n": 1,
-            "best_of": 1,
-            "presence_penalty": 0.0,
-            "frequency_penalty": 0.0,
-            "repetition_penalty": 1.0,
-            "temperature": 1.0,
-            "dynatemp_min": 0.0,
-            "dynatemp_max": 0.0,
-            "dynatemp_exponent": 1.0,
-            "temperature_last": False,
-            "top_p": 1.0,
-            "top_k": -1,
-            "top_a": 0.0,
-            "min_p": 0.0,
-            "tfs": 1.0,
-            "eta_cutoff": 0.0,
-            "epsilon_cutoff": 0.0,
-            "typical_p": 1.0,
-            "smoothing_factor": 0.0,
-            "smoothing_curve": 1.0,
-            "seed": None,
-            "use_beam_search": False,
-            "length_penalty": 1.0,
-            "early_stopping": False,
-            "stop": [],
-            "stop_token_ids": [],
-            "ignore_eos": False,
-            "max_tokens": 16,
-            "min_tokens": 0,
-            "logprobs": None,
-            "prompt_logprobs": None,
-            "detokenize": True,
-            "custom_token_bans": [],
-            "skip_special_tokens": True,
-            "spaces_between_special_tokens": True,
-            "include_stop_str_in_output": False,
-            "truncate_prompt_tokens": None,
-            "xtc_threshold": 0.1,
-            "xtc_probability": 0,
-        }
+            self.stop_token_ids = list(self.stop_token_ids)
+        self.logprobs = 1 if self.logprobs is True else self.logprobs
+        self.prompt_logprobs = (1 if self.prompt_logprobs is True else
+                                self.prompt_logprobs)
 
         # Number of characters to hold back for stop string evaluation
         # until sequence is finished.
-        if self.stop and not include_stop_str_in_output:
+        if self.stop and not self.include_stop_str_in_output:
             self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
-        else:
-            self.output_text_buffer_length = 0
 
         self._verify_args()
         if self.use_beam_search:
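Unlike the hand-written `__init__` it replaces, a `msgspec.Struct` generates its constructor from the field declarations, so all of the old normalization logic moves into `__post_init__`, which msgspec calls after both construction and typed decoding. Mutable defaults need `msgspec.field(default_factory=...)`, just as in dataclasses. A cut-down sketch of the same pattern (fields abbreviated from the diff above):

```python
import msgspec
from typing import List, Optional, Set, Union

class MiniParams(msgspec.Struct, omit_defaults=True, dict=True):
    stop: Union[None, str, List[str]] = None
    seed: Optional[int] = None
    # Mutable default: a shared set() literal would be a bug, so use a factory.
    _ids: Set[int] = msgspec.field(default_factory=set)

    def __post_init__(self) -> None:
        # Normalize the union into a concrete list, mirroring the diff.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]
        if self.seed == -1:  # -1 is treated as "no fixed seed"
            self.seed = None

print(MiniParams(stop="###").stop)   # ['###']
print(MiniParams(seed=-1).seed)      # None
```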
@@ -322,11 +291,12 @@ class SamplingParams:
             self.top_a = 0.0
             self._verify_greedy_sampling()
         # eos_token_id is added to this by the engine
-        self.all_stop_token_ids = set(self.stop_token_ids)
+        self._all_stop_token_ids = set(self.stop_token_ids)
 
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        assert isinstance(self.best_of, int)
         if self.best_of < self.n:
             raise ValueError(f"best_of must be greater than or equal to n, "
                              f"got n={self.n} and best_of={self.best_of}.")
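The bare `assert isinstance(...)` lines added in this and the following hunks serve the type checker as much as the runtime: a field declared `Optional[int]` is still `Optional[int]` to static analysis in every method, even after `__post_init__` has filled it in, so an explicit assert narrows the type before the comparison. A minimal sketch of the pattern (the `P` class is hypothetical):

```python
from typing import Optional

class P:
    best_of: Optional[int] = None
    n: int = 1

    def finalize(self) -> None:
        self.best_of = self.best_of or self.n

    def verify(self) -> None:
        # The attribute is still Optional[int] to mypy/pyright here, so
        # narrow it explicitly (and guard at runtime) before comparing.
        assert isinstance(self.best_of, int)
        if self.best_of < self.n:
            raise ValueError(
                f"best_of must be >= n, got n={self.n}, best_of={self.best_of}")
```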
@@ -383,6 +353,7 @@ class SamplingParams:
                 and self.truncate_prompt_tokens < 1):
             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                              f"got {self.truncate_prompt_tokens}")
+        assert isinstance(self.stop, list)
         if any(not stop_str for stop_str in self.stop):
             raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
@@ -424,6 +395,7 @@ class SamplingParams:
                 "default value of 1.0 when not using beam search.")
 
     def _verify_greedy_sampling(self) -> None:
+        assert isinstance(self.best_of, int)
         if self.best_of > 1:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
@@ -441,7 +413,7 @@ class SamplingParams:
         if model_eos_token_id is not None:
             # Add the eos token id into the sampling_params to support
             # min_tokens processing.
-            self.all_stop_token_ids.add(model_eos_token_id)
+            self._all_stop_token_ids.add(model_eos_token_id)
 
         # Update eos_token_id for generation
         if (eos_ids := generation_config.get("eos_token_id")) is not None:
@@ -453,8 +425,9 @@ class SamplingParams:
             # purposes.
             eos_ids.discard(model_eos_token_id)
             if eos_ids:
-                self.all_stop_token_ids.update(eos_ids)
+                self._all_stop_token_ids.update(eos_ids)
                 if not self.ignore_eos:
+                    assert isinstance(self.stop_token_ids, list)
                     eos_ids.update(self.stop_token_ids)
                     self.stop_token_ids = list(eos_ids)
 
@@ -468,6 +441,10 @@ class SamplingParams:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM
 
+    @property
+    def all_stop_token_ids(self) -> Set[int]:
+        return self._all_stop_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy excluding LogitsProcessor objects.
         LogitsProcessor objects are excluded because they may contain an
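For context on the `clone()` docstring above: the usual way to deep-copy an object while excluding specific members (and most likely what the body outside this hunk does) is to pre-seed `copy.deepcopy`'s memo dict. A self-contained sketch, with `Holder` and `LogitsProc` as hypothetical stand-ins:

```python
import copy

class LogitsProc:            # stand-in for an arbitrary, heavyweight callable
    pass

class Holder:
    def __init__(self, procs):
        self.logits_processors = procs
        self.stop = ["###"]

lp = LogitsProc()
h = Holder([lp])
# Seeding the memo with id(obj) -> obj makes deepcopy return the original
# object wherever it is encountered, skipping the copy for just those objects.
h2 = copy.deepcopy(h, memo={id(lp): lp})
assert h2.logits_processors[0] is lp   # shared, not copied
assert h2.stop is not h.stop           # everything else is still deep-copied
```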