|
@@ -1,15 +1,20 @@
|
|
|
"""Sampling parameters for text generation."""
|
|
|
import copy
|
|
|
+import os
|
|
|
from enum import IntEnum
|
|
|
from functools import cached_property
|
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
|
|
import torch
|
|
|
+from loguru import logger
|
|
|
from pydantic import Field
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
_SAMPLING_EPS = 1e-5
|
|
|
|
|
|
+APHRODITE_NO_DEPRECATION_WARNING = bool(
|
|
|
+ int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
|
|
|
+
|
|
|
|
|
|
class SamplingType(IntEnum):
|
|
|
GREEDY = 0
|
|
@@ -152,12 +157,6 @@ class SamplingParams:
|
|
|
eta_cutoff: float = 0.0,
|
|
|
epsilon_cutoff: float = 0.0,
|
|
|
typical_p: float = 1.0,
|
|
|
- mirostat_mode: int = 0,
|
|
|
- mirostat_tau: float = 0,
|
|
|
- mirostat_eta: float = 0,
|
|
|
- dynatemp_min: float = 0,
|
|
|
- dynatemp_max: float = 0,
|
|
|
- dynatemp_exponent: float = 1,
|
|
|
smoothing_factor: float = 0.0,
|
|
|
smoothing_curve: float = 1.0,
|
|
|
seed: Optional[int] = None,
|
|
@@ -193,12 +192,6 @@ class SamplingParams:
|
|
|
self.eta_cutoff = eta_cutoff
|
|
|
self.epsilon_cutoff = epsilon_cutoff
|
|
|
self.typical_p = typical_p
|
|
|
- self.mirostat_mode = mirostat_mode
|
|
|
- self.mirostat_tau = mirostat_tau
|
|
|
- self.mirostat_eta = mirostat_eta
|
|
|
- self.dynatemp_min = dynatemp_min
|
|
|
- self.dynatemp_max = dynatemp_max
|
|
|
- self.dynatemp_exponent = dynatemp_exponent
|
|
|
self.smoothing_factor = smoothing_factor
|
|
|
self.smoothing_curve = smoothing_curve
|
|
|
if seed == -1:
|
|
@@ -252,12 +245,6 @@ class SamplingParams:
|
|
|
"eta_cutoff": 0.0,
|
|
|
"epsilon_cutoff": 0.0,
|
|
|
"typical_p": 1.0,
|
|
|
- "mirostat_mode": 0,
|
|
|
- "mirostat_tau": 0,
|
|
|
- "mirostat_eta": 0,
|
|
|
- "dynatemp_min": 0,
|
|
|
- "dynatemp_max": 0,
|
|
|
- "dynatemp_exponent": 1,
|
|
|
"smoothing_factor": 0.0,
|
|
|
"smoothing_curve": 1.0,
|
|
|
"seed": None,
|
|
@@ -288,6 +275,12 @@ class SamplingParams:
|
|
|
|
|
|
self._verify_args()
|
|
|
if self.use_beam_search:
|
|
|
+ if not APHRODITE_NO_DEPRECATION_WARNING:
|
|
|
+ logger.warning(
|
|
|
+ "[IMPORTANT] We plan to discontinue the support for beam "
|
|
|
+ "search in the next major release. Set "
|
|
|
+ "APHRODITE_NO_DEPRECATION_WARNING=1 to "
|
|
|
+ "suppress this warning.")
|
|
|
self._verify_beam_search()
|
|
|
else:
|
|
|
self._verify_non_beam_search()
|
|
@@ -340,32 +333,6 @@ class SamplingParams:
|
|
|
if not 0.0 <= self.typical_p <= 1.0:
|
|
|
raise ValueError(
|
|
|
f"typical_p must be in (0, 1], got {self.typical_p}.")
|
|
|
- if not self.dynatemp_min >= 0:
|
|
|
- raise ValueError(
|
|
|
- f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
|
|
|
- if not self.dynatemp_max >= 0:
|
|
|
- raise ValueError(
|
|
|
- f"dynatemp_max must be non negative, got {self.dynatemp_max}.")
|
|
|
- if not self.dynatemp_exponent >= 0:
|
|
|
- raise ValueError(f"dynatemp_exponent must be non negative, got "
|
|
|
- f"{self.dynatemp_exponent}.")
|
|
|
- if not self.smoothing_factor >= 0:
|
|
|
- raise ValueError(f"smoothing_factor must be non negative, got "
|
|
|
- f"{self.smoothing_factor}.")
|
|
|
- if not self.smoothing_curve >= 1.0:
|
|
|
- raise ValueError(f"smoothing_curve must larger than 1, got "
|
|
|
- f"{self.smoothing_curve}.")
|
|
|
- if self.mirostat_mode:
|
|
|
- if not self.mirostat_mode == 2:
|
|
|
- raise ValueError(
|
|
|
- "Only Mirostat v2 (2) and disabled (0) supported, "
|
|
|
- f"got {self.mirostat_mode}")
|
|
|
- if not self.mirostat_eta >= 0:
|
|
|
- raise ValueError(
|
|
|
- f"mirostat_eta must be positive, got {self.mirostat_eta}")
|
|
|
- if not self.mirostat_tau >= 0:
|
|
|
- raise ValueError(
|
|
|
- f"mirostat_tau must be positive, got {self.mirostat_tau}")
|
|
|
if self.max_tokens is not None and self.max_tokens < 1:
|
|
|
raise ValueError(
|
|
|
f"max_tokens must be at least 1, got {self.max_tokens}.")
|