Преглед на файлове

chore: deprecation warning for beam search

AlpinDale преди 6 месеца
родител
ревизия
bf15e1b4e8
променени са 1 файла, в които са добавени 11 реда и са изтрити 44 реда
  1. 11 44
      aphrodite/common/sampling_params.py

+ 11 - 44
aphrodite/common/sampling_params.py

@@ -1,15 +1,20 @@
 """Sampling parameters for text generation."""
 import copy
+import os
 from enum import IntEnum
 from functools import cached_property
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
+from loguru import logger
 from pydantic import Field
 from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
 
+APHRODITE_NO_DEPRECATION_WARNING = bool(
+    int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
+
 
 class SamplingType(IntEnum):
     GREEDY = 0
@@ -152,12 +157,6 @@ class SamplingParams:
         eta_cutoff: float = 0.0,
         epsilon_cutoff: float = 0.0,
         typical_p: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 0,
-        mirostat_eta: float = 0,
-        dynatemp_min: float = 0,
-        dynatemp_max: float = 0,
-        dynatemp_exponent: float = 1,
         smoothing_factor: float = 0.0,
         smoothing_curve: float = 1.0,
         seed: Optional[int] = None,
@@ -193,12 +192,6 @@ class SamplingParams:
         self.eta_cutoff = eta_cutoff
         self.epsilon_cutoff = epsilon_cutoff
         self.typical_p = typical_p
-        self.mirostat_mode = mirostat_mode
-        self.mirostat_tau = mirostat_tau
-        self.mirostat_eta = mirostat_eta
-        self.dynatemp_min = dynatemp_min
-        self.dynatemp_max = dynatemp_max
-        self.dynatemp_exponent = dynatemp_exponent
         self.smoothing_factor = smoothing_factor
         self.smoothing_curve = smoothing_curve
         if seed == -1:
@@ -252,12 +245,6 @@ class SamplingParams:
             "eta_cutoff": 0.0,
             "epsilon_cutoff": 0.0,
             "typical_p": 1.0,
-            "mirostat_mode": 0,
-            "mirostat_tau": 0,
-            "mirostat_eta": 0,
-            "dynatemp_min": 0,
-            "dynatemp_max": 0,
-            "dynatemp_exponent": 1,
             "smoothing_factor": 0.0,
             "smoothing_curve": 1.0,
             "seed": None,
@@ -288,6 +275,12 @@ class SamplingParams:
 
         self._verify_args()
         if self.use_beam_search:
+            if not APHRODITE_NO_DEPRECATION_WARNING:
+                logger.warning(
+                    "[IMPORTANT] We plan to discontinue the support for beam "
+                    "search in the next major release. Set "
+                    "APHRODITE_NO_DEPRECATION_WARNING=1 to "
+                    "suppress this warning.")
             self._verify_beam_search()
         else:
             self._verify_non_beam_search()
@@ -340,32 +333,6 @@ class SamplingParams:
         if not 0.0 <= self.typical_p <= 1.0:
             raise ValueError(
                 f"typical_p must be in (0, 1], got {self.typical_p}.")
-        if not self.dynatemp_min >= 0:
-            raise ValueError(
-                f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
-        if not self.dynatemp_max >= 0:
-            raise ValueError(
-                f"dynatemp_max must be non negative, got {self.dynatemp_max}.")
-        if not self.dynatemp_exponent >= 0:
-            raise ValueError(f"dynatemp_exponent must be non negative, got "
-                             f"{self.dynatemp_exponent}.")
-        if not self.smoothing_factor >= 0:
-            raise ValueError(f"smoothing_factor must be non negative, got "
-                             f"{self.smoothing_factor}.")
-        if not self.smoothing_curve >= 1.0:
-            raise ValueError(f"smoothing_curve must larger than 1, got "
-                             f"{self.smoothing_curve}.")
-        if self.mirostat_mode:
-            if not self.mirostat_mode == 2:
-                raise ValueError(
-                    "Only Mirostat v2 (2) and disabled (0) supported, "
-                    f"got {self.mirostat_mode}")
-            if not self.mirostat_eta >= 0:
-                raise ValueError(
-                    f"mirostat_eta must be positive, got {self.mirostat_eta}")
-            if not self.mirostat_tau >= 0:
-                raise ValueError(
-                    f"mirostat_tau must be positive, got {self.mirostat_tau}")
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")