Procházet zdrojové kódy

feat: bring back dynatemp (#754)

* feat: bring back dynatemp

Co-authored-by: 50h100a <136940546+50h100a@users.noreply.github.com>

* remove unused global

* fix logging

* re-enable custom token bans in OAI API

---------

Co-authored-by: 50h100a <136940546+50h100a@users.noreply.github.com>
AlpinDale před 5 měsíci
rodič
revize
ad181e3fef

+ 9 - 0
aphrodite/common/sampling_params.py

@@ -155,6 +155,9 @@ class SamplingParams:
         frequency_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
         temperature: float = 1.0,
+        dynatemp_min: float = 0.0,
+        dynatemp_max: float = 0.0,
+        dynatemp_exponent: float = 1.0,
         temperature_last: bool = False,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -199,6 +202,9 @@ class SamplingParams:
                 f"We have capped the temperature to {_MAX_TEMP}.")
             temperature = min(temperature, _MAX_TEMP)
         self.temperature = temperature
+        self.dynatemp_min = dynatemp_min
+        self.dynatemp_max = dynatemp_max
+        self.dynatemp_exponent = dynatemp_exponent
         self.temperature_last = temperature_last
         self.top_p = top_p
         self.top_k = top_k
@@ -255,6 +261,9 @@ class SamplingParams:
             "frequency_penalty": 0.0,
             "repetition_penalty": 1.0,
             "temperature": 1.0,
+            "dynatemp_min": 0.0,
+            "dynatemp_max": 0.0,
+            "dynatemp_exponent": 1.0,
             "temperature_last": False,
             "top_p": 1.0,
             "top_k": -1,

+ 13 - 9
aphrodite/endpoints/openai/api_server.py

@@ -190,8 +190,8 @@ def mount_metrics(app: FastAPI):
                                    multiprocess)
     prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
     if prometheus_multiproc_dir_path is not None:
-        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
-                    prometheus_multiproc_dir_path)
+        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
+                    "as PROMETHEUS_MULTIPROC_DIR")
         registry = CollectorRegistry()
         multiprocess.MultiProcessCollector(registry)
         # Add prometheus asgi middleware to route /metrics requests
@@ -344,12 +344,6 @@ def prepare_engine_payload(
     if not kai_payload.genkey:
         kai_payload.genkey = f"kai-{random_uuid()}"
 
-    # if kai_payload.max_context_length > engine_args.max_model_len:
-    #     raise ValueError(
-    #         f"max_context_length ({kai_payload.max_context_length}) "
-    #         "must be less than or equal to "
-    #         f"max_model_len ({engine_args.max_model_len})")
-
     kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
     kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
     if kai_payload.temperature < _SAMPLING_EPS:
@@ -357,6 +351,11 @@ def prepare_engine_payload(
         kai_payload.top_p = 1.0
         kai_payload.top_k = -1
 
+    dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range / 2 \
+        if kai_payload.dynatemp_range else None
+    dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range / 2 \
+        if kai_payload.dynatemp_range else None
+
     sampling_params = SamplingParams(
         n=kai_payload.n,
         best_of=kai_payload.n,
@@ -378,6 +377,11 @@ def prepare_engine_payload(
         if kai_payload.use_default_badwordsids else [],
         max_tokens=kai_payload.max_length,
         seed=kai_payload.sampler_seed,
+        dynatemp_min=dynatemp_min,
+        dynatemp_max=dynatemp_max,
+        dynatemp_exponent=kai_payload.dynatemp_exponent,
+        xtc_probability=kai_payload.xtc_probability,
+        xtc_threshold=kai_payload.xtc_threshold,
     )
 
     max_input_tokens = max(
@@ -712,7 +716,7 @@ async def init_app(
 
     if args.launch_kobold_api:
         _set_badwords(tokenizer, model_config.hf_config)
-
+    
     return app
 
 

+ 20 - 61
aphrodite/endpoints/openai/protocol.py

@@ -4,8 +4,7 @@ import time
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
-from pydantic import (BaseModel, ConfigDict, Field, model_validator,
-                      root_validator)
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 
@@ -148,6 +147,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     xtc_threshold: Optional[float] = 0.1
     xtc_probability: Optional[float] = 0.0
+    dynatemp_min: Optional[float] = 0.0
+    dynatemp_max: Optional[float] = 0.0
+    dynatemp_exponent: Optional[float] = 1.0
+    custom_token_bans: Optional[List[int]] = None
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
@@ -287,6 +290,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
             temperature_last=self.temperature_last,
             xtc_threshold=self.xtc_threshold,
             xtc_probability=self.xtc_probability,
+            dynatemp_min=self.dynatemp_min,
+            dynatemp_max=self.dynatemp_max,
+            dynatemp_exponent=self.dynatemp_exponent,
+            custom_token_bans=self.custom_token_bans,
         )
 
     @model_validator(mode='before')
@@ -394,6 +401,10 @@ class CompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     xtc_threshold: Optional[float] = 0.1
     xtc_probability: Optional[float] = 0.0
+    dynatemp_min: Optional[float] = 0.0
+    dynatemp_max: Optional[float] = 0.0
+    dynatemp_exponent: Optional[float] = 1.0
+    custom_token_bans: Optional[List[int]] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -492,6 +503,10 @@ class CompletionRequest(OpenAIBaseModel):
             temperature_last=self.temperature_last,
             xtc_threshold=self.xtc_threshold,
             xtc_probability=self.xtc_probability,
+            dynatemp_min=self.dynatemp_min,
+            dynatemp_max=self.dynatemp_max,
+            dynatemp_exponent=self.dynatemp_exponent,
+            custom_token_bans=self.custom_token_bans,
         )
 
     @model_validator(mode="before")
@@ -771,45 +786,6 @@ class DetokenizeResponse(OpenAIBaseModel):
 # ========== KoboldAI ========== #
 
 
-class KoboldSamplingParams(BaseModel):
-    n: int = Field(1, alias="n")
-    best_of: Optional[int] = Field(None, alias="best_of")
-    presence_penalty: float = Field(0.0, alias="presence_penalty")
-    frequency_penalty: float = Field(0.0, alias="rep_pen")
-    temperature: float = Field(1.0, alias="temperature")
-    dynatemp_range: Optional[float] = 0.0
-    dynatemp_exponent: Optional[float] = 1.0
-    smoothing_factor: Optional[float] = 0.0
-    smoothing_curve: Optional[float] = 1.0
-    top_p: float = Field(1.0, alias="top_p")
-    top_k: float = Field(-1, alias="top_k")
-    min_p: float = Field(0.0, alias="min_p")
-    top_a: float = Field(0.0, alias="top_a")
-    tfs: float = Field(1.0, alias="tfs")
-    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
-    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
-    typical_p: float = Field(1.0, alias="typical_p")
-    use_beam_search: bool = Field(False, alias="use_beam_search")
-    length_penalty: float = Field(1.0, alias="length_penalty")
-    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
-    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
-    include_stop_str_in_output: Optional[bool] = False
-    ignore_eos: bool = Field(False, alias="ignore_eos")
-    max_tokens: int = Field(16, alias="max_length")
-    logprobs: Optional[int] = Field(None, alias="logprobs")
-    custom_token_bans: Optional[List[int]] = Field(None,
-                                                   alias="custom_token_bans")
-
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
-        best_of = values.get("best_of")
-        n = values.get("n")
-        if best_of is not None and (best_of <= 0 or best_of > n):
-            raise ValueError(
-                "best_of must be a positive integer less than or equal to n")
-        return values
-
-
 class KAIGenerationInputSchema(BaseModel):
     genkey: Optional[str] = None
     prompt: str
@@ -817,8 +793,6 @@ class KAIGenerationInputSchema(BaseModel):
     max_context_length: int
     max_length: int
     rep_pen: Optional[float] = 1.0
-    rep_pen_range: Optional[int] = None
-    rep_pen_slope: Optional[float] = None
     top_k: Optional[int] = 0
     top_a: Optional[float] = 0.0
     top_p: Optional[float] = 1.0
@@ -832,31 +806,16 @@ class KAIGenerationInputSchema(BaseModel):
     dynatemp_exponent: Optional[float] = 1.0
     smoothing_factor: Optional[float] = 0.0
     smoothing_curve: Optional[float] = 1.0
-    use_memory: Optional[bool] = None
-    use_story: Optional[bool] = None
-    use_authors_note: Optional[bool] = None
-    use_world_info: Optional[bool] = None
-    use_userscripts: Optional[bool] = None
-    soft_prompt: Optional[str] = None
-    disable_output_formatting: Optional[bool] = None
-    frmtrmblln: Optional[bool] = None
-    frmtrmspch: Optional[bool] = None
-    singleline: Optional[bool] = None
+    xtc_threshold: Optional[float] = 0.1
+    xtc_probability: Optional[float] = 0.0
     use_default_badwordsids: Optional[bool] = None
-    mirostat: Optional[int] = 0
-    mirostat_tau: Optional[float] = 0.0
-    mirostat_eta: Optional[float] = 0.0
-    disable_input_formatting: Optional[bool] = None
-    frmtadsnsp: Optional[bool] = None
     quiet: Optional[bool] = None
     # pylint: disable=unexpected-keyword-arg
-    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
     sampler_seed: Optional[int] = None
-    sampler_full_determinism: Optional[bool] = None
     stop_sequence: Optional[List[str]] = None
     include_stop_str_in_output: Optional[bool] = False
 
-    @root_validator(pre=False, skip_on_failure=True)
+    @model_validator(mode='after')
     def check_context(cls, values):  # pylint: disable=no-self-argument
         assert values.get("max_length") <= values.get(
             "max_context_length"

+ 47 - 14
aphrodite/modeling/layers/sampler.py

@@ -70,14 +70,15 @@ class Sampler(nn.Module):
         self._sampling_tensors = None
 
         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
-         do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
-         do_quadratic, do_xtc,
-         do_temp_last) = SamplingTensors.from_sampling_metadata(
+        (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
+         do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
+         do_typical_ps, do_quadratic, do_xtc, do_temp_last
+         ) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
+        self._do_temperatures = do_temperatures
         self._do_top_p_top_k = do_top_p_top_k
         self._do_top_as = do_top_as
         self._do_min_p = do_min_p
@@ -115,6 +116,7 @@ class Sampler(nn.Module):
         assert self._sampling_tensors is not None
         sampling_tensors = self._sampling_tensors
         do_penalties = self._do_penalties
+        do_temperatures = self._do_temperatures
         do_top_p_top_k = self._do_top_p_top_k
         do_top_as = self._do_top_as
         do_min_p = self._do_min_p
@@ -137,11 +139,11 @@ class Sampler(nn.Module):
                                       sampling_tensors.repetition_penalties)
 
         # Apply temperature scaling if not doing temp_last.
-        if not do_temp_last:
-            # Use float32 to apply temp.
-            # Use in-place division to avoid creating a new tensor.
-            logits = logits.to(torch.float)
-            logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+        if do_temperatures and not do_temp_last:
+            _apply_temperatures(logits, sampling_tensors.temperatures,
+                                sampling_tensors.dynatemp_mins,
+                                sampling_tensors.dynatemp_maxs,
+                                sampling_tensors.dynatemp_exps)
 
         if do_top_p_top_k:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
@@ -177,11 +179,11 @@ class Sampler(nn.Module):
                 logits, sampling_tensors.xtc_thresholds,
                 sampling_tensors.xtc_probabilities)
 
-        if do_temp_last:
-            # Use float32 to apply temp.
-            # Use in-place division to avoid creating a new tensor.
-            logits = logits.to(torch.float)
-            logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
+        if do_temperatures and do_temp_last:
+            _apply_temperatures(logits, sampling_tensors.temperatures,
+                                sampling_tensors.dynatemp_mins,
+                                sampling_tensors.dynatemp_maxs,
+                                sampling_tensors.dynatemp_exps)
 
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         logits = _apply_token_bans(logits, banned_tokens)
@@ -294,6 +296,37 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits
 
 
+def _apply_temperatures(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+    dynatemp_mins: torch.Tensor,
+    dynatemp_maxs: torch.Tensor,
+    dynatemp_exps: torch.Tensor,
+) -> None:
+    dynatemp_mask = dynatemp_exps != 0
+    dynatemp_mins = dynatemp_mins[dynatemp_mask]
+    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
+    dynatemp_exps = dynatemp_exps[dynatemp_mask]
+
+    dynatemp_logits = logits[dynatemp_mask]
+    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
+    dynatemp_probs = dynatemp_shifted_logits.exp()
+    dynatemp_entropies = -(dynatemp_probs *
+                           dynatemp_shifted_logits).nansum(dim=-1)
+    dynatemp_max_entropies = torch.log_(
+        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
+    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
+    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
+                normalized_entropies.pow_(dynatemp_exps))
+
+    temperatures[dynatemp_mask] = dyn_temp
+    temperatures[temperatures <= 0.0] = 1.0
+    # Use float32 to apply temp.
+    # Use in-place division to avoid creating a new tensor.
+    logits = logits.to(torch.float)
+    logits.div_(temperatures.unsqueeze(dim=1))
+
+
 def _apply_token_bans(logits: torch.Tensor,
                       banned_tokens: List[List[int]]) -> torch.Tensor:
     for i, banned_token_ids in enumerate(banned_tokens):

+ 48 - 6
aphrodite/modeling/sampling_metadata.py

@@ -367,6 +367,9 @@ class SamplingTensors:
     """Tensors for sampling."""
 
     temperatures: torch.Tensor
+    dynatemp_mins: torch.Tensor
+    dynatemp_maxs: torch.Tensor
+    dynatemp_exps: torch.Tensor
     temperature_lasts: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
@@ -400,7 +403,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -410,6 +413,9 @@ class SamplingTensors:
         output_tokens: List[array] = []
         top_ks: List[int] = []
         temperatures: List[float] = []
+        dynatemp_mins: List[float] = []
+        dynatemp_maxs: List[float] = []
+        dynatemp_exps: List[float] = []
         temperature_lasts: List[bool] = []
         top_ps: List[float] = []
         top_as: List[float] = []
@@ -428,6 +434,7 @@ class SamplingTensors:
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
         do_penalties = False
+        do_temperatures = False
         do_top_p_top_k = False
         do_top_as = False
         do_min_p = False
@@ -451,6 +458,9 @@ class SamplingTensors:
             seq_ids = seq_group.seq_ids
             sampling_params = seq_group.sampling_params
             temperature = sampling_params.temperature
+            dynatemp_min = sampling_params.dynatemp_min
+            dynatemp_max = sampling_params.dynatemp_max
+            dynatemp_exp = sampling_params.dynatemp_exponent
             temperature_last = sampling_params.temperature_last
             p = sampling_params.presence_penalty
             f = sampling_params.frequency_penalty
@@ -475,6 +485,8 @@ class SamplingTensors:
                 # (i.e., greedy sampling or beam search).
                 # Set the temperature to 1 to avoid division by zero.
                 temperature = 1.0
+            if not do_temperatures and temperature != 1.0:
+                do_temperatures = True
             if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                        or top_k != vocab_size):
                 do_top_p_top_k = True
@@ -510,6 +522,9 @@ class SamplingTensors:
                 assert query_len is not None
                 prefill_len = len(seq_group.prompt_logprob_indices)
                 temperatures += [temperature] * prefill_len
+                dynatemp_mins += [dynatemp_min] * prefill_len
+                dynatemp_maxs += [dynatemp_max] * prefill_len
+                dynatemp_exps += [dynatemp_exp] * prefill_len
                 temperature_lasts += [temperature_last] * prefill_len
                 top_ps += [top_p] * prefill_len
                 top_ks += [top_k] * prefill_len
@@ -531,6 +546,9 @@ class SamplingTensors:
                 sample_lens = len(seq_group.sample_indices)
                 assert sample_lens == len(seq_ids)
                 temperatures += [temperature] * len(seq_ids)
+                dynatemp_mins += [dynatemp_min] * len(seq_ids)
+                dynatemp_maxs += [dynatemp_max] * len(seq_ids)
+                dynatemp_exps += [dynatemp_exp] * len(seq_ids)
                 temperature_lasts += [temperature_last] * len(seq_ids)
                 top_ps += [top_p] * len(seq_ids)
                 top_ks += [top_k] * len(seq_ids)
@@ -587,18 +605,21 @@ class SamplingTensors:
                         output_tokens.append(seq_data.output_token_ids_array)
 
         sampling_tensors = SamplingTensors.from_lists(
-            temperatures, temperature_lasts, top_ps, top_ks, top_as, min_ps,
+            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
+            temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
             tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
             smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
             sample_indices, prompt_tokens, output_tokens, vocab_size,
             extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
-                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-                do_typical_ps, do_quadratic, do_xtc, do_temp_last)
+        return (sampling_tensors, do_penalties, do_temperatures,
+                do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
+                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
+                do_temp_last)
 
     @classmethod
-    def from_lists(cls, temperatures: List[float],
+    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
+                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
                    temperature_lasts: List[bool], top_ps: List[float],
                    top_ks: List[int], top_as: List[float],
                    min_ps: List[float], presence_penalties: List[float],
@@ -643,6 +664,24 @@ class SamplingTensors:
             dtype=dtype,
             pin_memory=pin_memory,
         )
+        dynatemp_mins_t = torch.tensor(
+            dynatemp_mins,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dynatemp_maxs_t = torch.tensor(
+            dynatemp_maxs,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dynatemp_exps_t = torch.tensor(
+            dynatemp_exps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
         temp_lasts_t = torch.tensor(
             temperature_lasts,
             device="cpu",
@@ -751,6 +790,9 @@ class SamplingTensors:
 
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
+            dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
+            dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
+            dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
             temperature_lasts=temp_lasts_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
             top_ks=top_ks_t.to(device=device, non_blocking=True),