
rather painful migration of sampling params

AlpinDale committed 4 months ago (22427b1d4c)

1 file changed, 128 insertions and 151 deletions

aphrodite/common/sampling_params.py  +128 −151
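
For readers unfamiliar with the pattern, here is a minimal sketch of what this migration does, using a trimmed-down, hypothetical MiniSamplingParams (not the real class): the hand-written __init__ is replaced by msgspec.Struct field declarations, with the normalization and validation logic moved into __post_init__.

    # Hedged sketch of the commit's migration pattern; MiniSamplingParams and
    # its reduced field set are illustrative assumptions, not the real
    # SamplingParams shown in the diff below.
    from typing import List, Optional, Union

    import msgspec


    class MiniSamplingParams(
            msgspec.Struct,
            omit_defaults=True,  # omit default-valued fields when serializing
            dict=True):          # give instances a __dict__, which
                                 # cached_property needs (structs are otherwise
                                 # slots-only)
        n: int = 1
        best_of: Optional[int] = None
        temperature: float = 1.0
        stop: Union[None, str, List[str]] = None

        def __post_init__(self) -> None:
            # Normalization that previously lived in __init__; msgspec calls
            # this after construction and after decoding.
            self.best_of = self.best_of or self.n
            if self.stop is None:
                self.stop = []
            elif isinstance(self.stop, str):
                self.stop = [self.stop]


    params = MiniSamplingParams(temperature=0.7, stop="###")
    assert params.stop == ["###"]
    # Unlike the old plain class, the struct round-trips through msgspec:
    decoded = msgspec.msgpack.decode(
        msgspec.msgpack.encode(params), type=MiniSamplingParams)

The practical payoff is that SamplingParams becomes serializable without the pydantic dependency; accordingly, the diff swaps pydantic's Field(ge=1) constraint for msgspec.Meta(ge=1) in the imports and the truncate_prompt_tokens annotation.
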

@@ -3,11 +3,11 @@ import copy
 import os
 from enum import IntEnum
 from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union
 
+import msgspec
 import torch
 from loguru import logger
-from pydantic import Field
 from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
@@ -34,7 +34,10 @@ first argument, and returns a modified tensor of logits
 to sample from."""
 
 
-class SamplingParams:
+class SamplingParams(
+    msgspec.Struct,
+    omit_defaults=True,
+    dict=True):
     """Sampling parameters for text generation.
 
     Overall, we follow the sampling parameters from the OpenAI text completion
@@ -147,161 +150,127 @@ class SamplingParams:
             0 disables the sampler, 1 makes it always happen.
     """
 
-    def __init__(
-        self,
-        n: int = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repetition_penalty: float = 1.0,
-        temperature: float = 1.0,
-        dynatemp_min: float = 0.0,
-        dynatemp_max: float = 0.0,
-        dynatemp_exponent: float = 1.0,
-        temperature_last: bool = False,
-        top_p: float = 1.0,
-        top_k: int = -1,
-        top_a: float = 0.0,
-        min_p: float = 0.0,
-        tfs: float = 1.0,
-        eta_cutoff: float = 0.0,
-        epsilon_cutoff: float = 0.0,
-        typical_p: float = 1.0,
-        smoothing_factor: float = 0.0,
-        smoothing_curve: float = 1.0,
-        seed: Optional[int] = None,
-        use_beam_search: bool = False,
-        length_penalty: float = 1.0,
-        early_stopping: Union[bool, str] = False,
-        stop: Union[None, str, List[str]] = None,
-        stop_token_ids: Optional[List[int]] = None,
-        include_stop_str_in_output: bool = False,
-        ignore_eos: bool = False,
-        max_tokens: Optional[int] = 16,
-        min_tokens: int = 0,
-        logprobs: Optional[int] = None,
-        prompt_logprobs: Optional[int] = None,
-        detokenize: bool = True,
-        custom_token_bans: Optional[List[int]] = None,
-        skip_special_tokens: bool = True,
-        spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[List[LogitsProcessorFunc]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
-        xtc_threshold: float = 0.1,
-        xtc_probability: float = 0,
-    ) -> None:
-        self.n = n
-        self.best_of = best_of if best_of is not None else n
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.repetition_penalty = repetition_penalty
-        if 0 < temperature < _MAX_TEMP:
+    n: int = 1
+    best_of: Optional[int] = None
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    repetition_penalty: float = 1.0
+    temperature: float = 1.0
+    dynatemp_min: float = 0.0
+    dynatemp_max: float = 0.0
+    dynatemp_exponent: float = 1.0
+    temperature_last: bool = False
+    top_p: float = 1.0
+    top_k: int = -1
+    top_a: float = 0.0
+    min_p: float = 0.0
+    tfs: float = 1.0
+    eta_cutoff: float = 0.0
+    epsilon_cutoff: float = 0.0
+    typical_p: float = 1.0
+    smoothing_factor: float = 0.0
+    smoothing_curve: float = 1.0
+    seed: Optional[int] = None
+    use_beam_search: bool = False
+    length_penalty: float = 1.0
+    early_stopping: Union[bool, str] = False
+    stop: Union[None, str, List[str]] = None
+    stop_token_ids: Optional[List[int]] = None
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    max_tokens: Optional[int] = 16
+    min_tokens: int = 0
+    logprobs: Optional[int] = None
+    prompt_logprobs: Optional[int] = None
+    detokenize: bool = True
+    custom_token_bans: Optional[List[int]] = None
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    # Optional[List[LogitsProcessorFunc]] type.
+    # We use Any here because the type above
+    # is not supported by msgspec.
+    logits_processors: Optional[Any] = None
+    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+    xtc_threshold: float = 0.1
+    xtc_probability: float = 0
+
+    # The below fields are not supposed to be used as an input.
+    # They are set in post_init.
+    output_text_buffer_length: int = 0
+    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
+
+    default_values = {
+        "n": 1,
+        "best_of": 1,
+        "presence_penalty": 0.0,
+        "frequency_penalty": 0.0,
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "dynatemp_min": 0.0,
+        "dynatemp_max": 0.0,
+        "dynatemp_exponent": 1.0,
+        "temperature_last": False,
+        "top_p": 1.0,
+        "top_k": -1,
+        "top_a": 0.0,
+        "min_p": 0.0,
+        "tfs": 1.0,
+        "eta_cutoff": 0.0,
+        "epsilon_cutoff": 0.0,
+        "typical_p": 1.0,
+        "smoothing_factor": 0.0,
+        "smoothing_curve": 1.0,
+        "seed": None,
+        "use_beam_search": False,
+        "length_penalty": 1.0,
+        "early_stopping": False,
+        "stop": [],
+        "stop_token_ids": [],
+        "ignore_eos": False,
+        "max_tokens": 16,
+        "min_tokens": 0,
+        "logprobs": None,
+        "prompt_logprobs": None,
+        "detokenize": True,
+        "custom_token_bans": [],
+        "skip_special_tokens": True,
+        "spaces_between_special_tokens": True,
+        "include_stop_str_in_output": False,
+        "truncate_prompt_tokens": None,
+        "xtc_threshold": 0.1,
+        "xtc_probability": 0,
+    }
+
+    def __post_init__(self) -> None:
+        self.best_of = self.best_of or self.n
+        if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
-                f"temperature {temperature} is less than {_MAX_TEMP}, "
-                f"which may cause numerical errors (NaN or Inf) in tensors. "
-                f"We have capped the temperature to {_MAX_TEMP}.")
-            temperature = min(temperature, _MAX_TEMP)
-        self.temperature = temperature
-        self.dynatemp_min = dynatemp_min
-        self.dynatemp_max = dynatemp_max
-        self.dynatemp_exponent = dynatemp_exponent
-        self.temperature_last = temperature_last
-        self.top_p = top_p
-        self.top_k = top_k
-        self.top_a = top_a
-        self.min_p = min_p
-        self.tfs = tfs
-        self.eta_cutoff = eta_cutoff
-        self.epsilon_cutoff = epsilon_cutoff
-        self.typical_p = typical_p
-        self.smoothing_factor = smoothing_factor
-        self.smoothing_curve = smoothing_curve
-        if seed == -1:
+                f"temperature {self.temperature} is less than {_MAX_TEMP}, "
+                "which may cause numerical errors NaN or inf in tensors. We "
+                f"have maxed it out to {_MAX_TEMP}.")
+            self.temperature = max(self.temperature, _MAX_TEMP)
+        if self.seed == -1:
             self.seed = None
         else:
-            self.seed = seed
-        self.use_beam_search = use_beam_search
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-        if stop is None:
+            self.seed = self.seed
+        if self.stop is None:
             self.stop = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
+        elif isinstance(self.stop, str):
+            self.stop = [self.stop]
         else:
-            self.stop = list(stop)
-        self.stop_token_ids = stop_token_ids or []
-        self.ignore_eos = ignore_eos
-        self.max_tokens = max_tokens
-        self.min_tokens = min_tokens
-        self.logprobs = 1 if logprobs is True else logprobs
-        self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
-        # NOTE: This parameter is only exposed at the engine level for now.
-        # It is not exposed in the OpenAI API server, as the OpenAI API does
-        # not support returning only a list of token IDs.
-        self.detokenize = detokenize
-        self.custom_token_bans = custom_token_bans or []
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.logits_processors = logits_processors or []
-        self.include_stop_str_in_output = include_stop_str_in_output
-        self.truncate_prompt_tokens = truncate_prompt_tokens
-        # Number of characters to hold back for stop string evaluation
-        # until sequence is finished.
-        if self.stop and not include_stop_str_in_output:
-            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+            self.stop = list(self.stop)
+        if self.stop_token_ids is None:
+            self.stop_token_ids = []
         else:
-            self.output_text_buffer_length = 0
-        self.xtc_threshold = xtc_threshold
-        self.xtc_probability = xtc_probability
-
-        self.default_values = {
-            "n": 1,
-            "best_of": 1,
-            "presence_penalty": 0.0,
-            "frequency_penalty": 0.0,
-            "repetition_penalty": 1.0,
-            "temperature": 1.0,
-            "dynatemp_min": 0.0,
-            "dynatemp_max": 0.0,
-            "dynatemp_exponent": 1.0,
-            "temperature_last": False,
-            "top_p": 1.0,
-            "top_k": -1,
-            "top_a": 0.0,
-            "min_p": 0.0,
-            "tfs": 1.0,
-            "eta_cutoff": 0.0,
-            "epsilon_cutoff": 0.0,
-            "typical_p": 1.0,
-            "smoothing_factor": 0.0,
-            "smoothing_curve": 1.0,
-            "seed": None,
-            "use_beam_search": False,
-            "length_penalty": 1.0,
-            "early_stopping": False,
-            "stop": [],
-            "stop_token_ids": [],
-            "ignore_eos": False,
-            "max_tokens": 16,
-            "min_tokens": 0,
-            "logprobs": None,
-            "prompt_logprobs": None,
-            "detokenize": True,
-            "custom_token_bans": [],
-            "skip_special_tokens": True,
-            "spaces_between_special_tokens": True,
-            "include_stop_str_in_output": False,
-            "truncate_prompt_tokens": None,
-            "xtc_threshold": 0.1,
-            "xtc_probability": 0,
-        }
+            self.stop_token_ids = list(self.stop_token_ids)
+        self.logprobs = 1 if self.logprobs is True else self.logprobs
+        self.prompt_logprobs = (1 if self.prompt_logprobs is True else
+                                self.prompt_logprobs)
 
         # Number of characters to hold back for stop string evaluation
         # until sequence is finished.
-        if self.stop and not include_stop_str_in_output:
+        if self.stop and not self.include_stop_str_in_output:
             self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
-        else:
-            self.output_text_buffer_length = 0
 
         self._verify_args()
         if self.use_beam_search:
@@ -322,11 +291,12 @@ class SamplingParams:
                 self.top_a = 0.0
                 self._verify_greedy_sampling()
         # eos_token_id is added to this by the engine
-        self.all_stop_token_ids = set(self.stop_token_ids)
+        self._all_stop_token_ids = set(self.stop_token_ids)
 
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        assert isinstance(self.best_of, int)
         if self.best_of < self.n:
             raise ValueError(f"best_of must be greater than or equal to n, "
                              f"got n={self.n} and best_of={self.best_of}.")
@@ -383,6 +353,7 @@ class SamplingParams:
                 and self.truncate_prompt_tokens < 1):
             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                              f"got {self.truncate_prompt_tokens}")
+        assert isinstance(self.stop, list)
         if any(not stop_str for stop_str in self.stop):
             raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
@@ -424,6 +395,7 @@ class SamplingParams:
                 "default value of 1.0 when not using beam search.")
 
     def _verify_greedy_sampling(self) -> None:
+        assert isinstance(self.best_of, int)
         if self.best_of > 1:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
@@ -441,7 +413,7 @@ class SamplingParams:
         if model_eos_token_id is not None:
             # Add the eos token id into the sampling_params to support
             # min_tokens processing.
-            self.all_stop_token_ids.add(model_eos_token_id)
+            self._all_stop_token_ids.add(model_eos_token_id)
 
         # Update eos_token_id for generation
         if (eos_ids := generation_config.get("eos_token_id")) is not None:
@@ -453,8 +425,9 @@ class SamplingParams:
                 # purposes.
                 eos_ids.discard(model_eos_token_id)
             if eos_ids:
-                self.all_stop_token_ids.update(eos_ids)
+                self._all_stop_token_ids.update(eos_ids)
                 if not self.ignore_eos:
+                    assert isinstance(self.stop_token_ids, list)
                     eos_ids.update(self.stop_token_ids)
                     self.stop_token_ids = list(eos_ids)
 
@@ -468,6 +441,10 @@ class SamplingParams:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM
 
+    @property
+    def all_stop_token_ids(self) -> Set[int]:
+        return self._all_stop_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy excluding LogitsProcessor objects.
         LogitsProcessor objects are excluded because they may contain an