Browse source

feat: add sampler_priority (#837)

* feat: add sampler_priority

* fix: sampler arg verification

* more clean-up and remove min_tokens from the order

* more cleaning up and logs

* alias sampler_priority to sampler_order
AlpinDale 3 months ago
parent
commit
dfa34d1b24

+ 44 - 0
aphrodite/common/sampling_params.py

@@ -23,6 +23,25 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2
     BEAM = 3
 
+class SamplerID(IntEnum):
+    # Mirror these in aphrodite/modeling/layers/sampler.py
+    # Values out of order to keep backwards compatibility
+    # with Koboldcpp values
+    DRY = 7
+    PENALTIES = 6
+    NO_REPEAT_NGRAM = 8
+    TEMPERATURE = 5
+    TOP_NSIGMA = 9
+    TOP_P_TOP_K = 0
+    TOP_A = 1
+    MIN_P = 2
+    TFS = 3
+    ETA_CUTOFF = 10
+    EPSILON_CUTOFF = 11
+    TYPICAL_P = 4
+    QUADRATIC = 12
+    XTC = 13
+
 
 LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                             Callable[[List[int], List[int], torch.Tensor],
@@ -175,6 +194,8 @@ class SamplingParams(
             Defaults to None.
         skew: Bias the token selection towards higher or lower probability
             tokens. Defaults to 0 (disabled).
+        sampler_priority: A list of integers to control the order in which
+            samplers are applied.
     """
 
     n: int = 1
@@ -227,6 +248,7 @@ class SamplingParams(
     dry_allowed_length: int = 2
     dry_sequence_breaker_ids: List[int] = []
     skew: float = 0.0
+    sampler_priority: Optional[List[int]] = []
     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
     output_text_buffer_length: int = 0
@@ -279,6 +301,7 @@ class SamplingParams(
         "dry_allowed_length": 2,
         "dry_sequence_breaker_ids": [],
         "skew": 0.0,
+        "sampler_priority": [],
     }
 
     def __post_init__(self) -> None:
@@ -428,6 +451,27 @@ class SamplingParams(
             raise ValueError(
                 "skew must be non-negative, got "
                 f"{self.skew}.")
+        
+        if self.sampler_priority is not None:
+            if not self.sampler_priority:
+                self.sampler_priority = None
+                return
+
+            if not isinstance(self.sampler_priority, list):
+                raise ValueError("sampler_priority must be a list of integers")
+            try:
+                provided_samplers = {
+                    SamplerID(x) for x in self.sampler_priority}
+            except ValueError as e:
+                raise ValueError(
+                    f"Invalid sampler ID in priority list: {e}") from e
+
+            required_samplers = set(SamplerID)
+            if not required_samplers.issubset(provided_samplers):
+                missing = required_samplers - provided_samplers
+                missing_names = [s.name for s in missing]
+                raise ValueError(f"Missing required samplers in priority list: "
+                                 f"{missing_names}")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
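
For context, a minimal usage sketch of the new validation (hypothetical call site; field and enum names are from the diff above). A custom order must cover every SamplerID member, and an empty list is normalized back to None:

    from aphrodite.common.sampling_params import SamplerID, SamplingParams

    # The validator rejects any list that is missing a SamplerID member;
    # here temperature is moved to the front of the default order.
    priority = [
        SamplerID.TEMPERATURE, SamplerID.DRY, SamplerID.PENALTIES,
        SamplerID.NO_REPEAT_NGRAM, SamplerID.TOP_NSIGMA,
        SamplerID.TOP_P_TOP_K, SamplerID.TOP_A, SamplerID.MIN_P,
        SamplerID.TFS, SamplerID.ETA_CUTOFF, SamplerID.EPSILON_CUTOFF,
        SamplerID.TYPICAL_P, SamplerID.QUADRATIC, SamplerID.XTC,
    ]
    params = SamplingParams(sampler_priority=[int(s) for s in priority])

    # An empty list means "no custom order" and is normalized to None.
    assert SamplingParams(sampler_priority=[]).sampler_priority is None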

+ 12 - 1
aphrodite/endpoints/openai/protocol.py

@@ -5,7 +5,8 @@ import time
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
+                      model_validator)
 from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 
@@ -160,6 +161,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
     nsigma: Optional[float] = 0.0
     skew: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
+    sampler_priority: Optional[List[int]] = Field(
+        default=[],
+        validation_alias=AliasChoices("sampler_priority",
+                                      "sampler_order"))
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
@@ -317,6 +322,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             nsigma=self.nsigma,
             skew=self.skew,
             custom_token_bans=self.custom_token_bans,
+            sampler_priority=self.sampler_priority,
         )
 
     @model_validator(mode='before')
@@ -436,6 +442,10 @@ class CompletionRequest(OpenAIBaseModel):
     nsigma: Optional[float] = 0.0
     skew: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
+    sampler_priority: Optional[List[int]] = Field(
+        default=[],
+        validation_alias=AliasChoices("sampler_priority",
+                                      "sampler_order"))
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -552,6 +562,7 @@ class CompletionRequest(OpenAIBaseModel):
             nsigma=self.nsigma,
             skew=self.skew,
             custom_token_bans=self.custom_token_bans,
+            sampler_priority=self.sampler_priority,
         )
 
     @model_validator(mode="before")
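
Client-side, the AliasChoices addition means a request may use either key. A hypothetical completion request follows (server URL and model name are placeholders; the IDs are the KoboldCpp-compatible values from the enum):

    import requests  # hypothetical client; any HTTP client would do

    payload = {
        "model": "my-model",            # placeholder
        "prompt": "Once upon a time",
        "max_tokens": 32,
        # Accepted under either name thanks to the alias:
        # "sampler_priority" or the KoboldCpp-style "sampler_order".
        # 5 = TEMPERATURE first, then the rest of the default order.
        "sampler_order": [5, 7, 6, 8, 9, 0, 1, 2, 3, 10, 11, 4, 12, 13],
    }
    resp = requests.post("http://localhost:2242/v1/completions", json=payload)
    print(resp.json()["choices"][0]["text"])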

+ 249 - 75
aphrodite/modeling/layers/sampler.py

@@ -2,11 +2,13 @@
 import itertools
 import os
 import warnings
+from enum import IntEnum
 from math import inf
 from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from loguru import logger
 
 import aphrodite._custom_ops as ops
 from aphrodite.common.sampling_params import SamplingType
@@ -36,6 +38,26 @@ APHRODITE_USE_SAMPLING_KERNELS = bool(int(
     os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
 
 
+class SamplerID(IntEnum):
+    # Mirror these in aphrodite/common/sampling_params.py
+    # Values out of order to keep backwards compatibility
+    # with Koboldcpp values
+    DRY = 7
+    PENALTIES = 6
+    NO_REPEAT_NGRAM = 8
+    TEMPERATURE = 5
+    TOP_NSIGMA = 9
+    TOP_P_TOP_K = 0
+    TOP_A = 1
+    MIN_P = 2
+    TFS = 3
+    ETA_CUTOFF = 10
+    EPSILON_CUTOFF = 11
+    TYPICAL_P = 4
+    QUADRATIC = 12
+    XTC = 13
+
+
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
 
@@ -151,90 +173,242 @@ class Sampler(nn.Module):
         do_temp_last = self._do_temp_last
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
-
-        if do_dry:
-            logits = _apply_dry(
-                logits,
-                sampling_tensors.prompt_tokens,
-                sampling_tensors.dry_multipliers,
-                sampling_tensors.dry_bases, 
-                sampling_tensors.dry_allowed_lengths,
-                sampling_tensors.dry_sequence_breaker_ids
-            )
-
-        # Apply presence and frequency penalties.
-        if do_penalties:
-            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
-                                      sampling_tensors.output_tokens,
-                                      sampling_tensors.presence_penalties,
-                                      sampling_tensors.frequency_penalties,
-                                      sampling_tensors.repetition_penalties)
-        
-        if do_no_repeat_ngrams:
-            logits = _apply_no_repeat_ngram(
-                logits,
-                sampling_tensors.prompt_tokens,
-                sampling_tensors.no_repeat_ngram_sizes)
-
-        # Apply temperature scaling if not doing temp_last.
-        if do_temperatures and not do_temp_last:
-            _apply_temperatures(logits, sampling_tensors.temperatures,
-                                sampling_tensors.dynatemp_mins,
-                                sampling_tensors.dynatemp_maxs,
-                                sampling_tensors.dynatemp_exps)
-
-        if do_nsigmas:
-            logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
-
-        if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
-
-        if do_top_as:
-            logits = _apply_top_a(logits, sampling_tensors.top_as)
-
-        if do_min_p:
-            logits = _apply_min_p(logits, sampling_tensors.min_ps)
-
-        if do_tfss:
-            logits = _apply_tfs(logits, sampling_tensors.tfss)
-
-        if do_eta_cutoffs:
-            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
-
-        if do_epsilon_cutoffs:
-            logits = _apply_epsilon_cutoff(logits,
-                                           sampling_tensors.epsilon_cutoffs)
-
-        if do_typical_ps:
-            logits = _apply_typical_sampling(logits,
-                                             sampling_tensors.typical_ps)
-
-        if do_quadratic:
-            logits = _apply_quadratic_sampling(
-                logits, sampling_tensors.smoothing_factors,
-                sampling_tensors.smoothing_curves)
-
-        if do_xtc:
-            logits = _apply_xtc_sampling(
-                logits, sampling_tensors.xtc_thresholds,
-                sampling_tensors.xtc_probabilities)
-
-        if do_temperatures and do_temp_last:
-            _apply_temperatures(logits, sampling_tensors.temperatures,
-                                sampling_tensors.dynatemp_mins,
-                                sampling_tensors.dynatemp_maxs,
-                                sampling_tensors.dynatemp_exps)
-
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         logits = _apply_token_bans(logits, banned_tokens)
 
+        sampler_order = None
+        if sampling_metadata.seq_groups:
+            sampler_order = sampling_metadata.seq_groups[
+                0].sampling_params.sampler_priority
+
+            # Warn if both custom order and temp_last are specified
+            if sampler_order is not None and do_temp_last:
+                logger.warning(
+                    "Both sampler_priority and temperature_last=True "
+                    "were specified. Using custom sampler_priority order "
+                    "and ignoring temperature_last.")
+
+        if sampler_order is None:
+            default_order = [
+                SamplerID.DRY,
+                SamplerID.PENALTIES,
+                SamplerID.NO_REPEAT_NGRAM,
+                SamplerID.TEMPERATURE,
+                SamplerID.TOP_NSIGMA,
+                SamplerID.TOP_P_TOP_K,
+                SamplerID.TOP_A,
+                SamplerID.MIN_P,
+                SamplerID.TFS,
+                SamplerID.ETA_CUTOFF,
+                SamplerID.EPSILON_CUTOFF,
+                SamplerID.TYPICAL_P,
+                SamplerID.QUADRATIC,
+                SamplerID.XTC,
+            ]
+
+            sampler_order = []
+            for sampler_id in default_order:
+                if sampler_id == SamplerID.TEMPERATURE and do_temp_last:
+                    continue
+                sampler_order.append(sampler_id)
+
+                if sampler_id == SamplerID.XTC and do_temp_last:
+                    sampler_order.append(SamplerID.TEMPERATURE)
+
+        if sampling_metadata.seq_groups and sampling_metadata.seq_groups[
+            0].is_prompt:
+            logger.debug("Sampler execution order: ")
+            for i, sampler_id in enumerate(sampler_order, 1):
+                logger.debug(f"{i}. {SamplerID(sampler_id).name}")
+
+            enabled_samplers = []
+            # ruff: noqa: E701
+            if do_penalties: enabled_samplers.append("PENALTIES")
+            if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM")
+            if do_temperatures: enabled_samplers.append("TEMPERATURE")
+            if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K")
+            if do_top_as: enabled_samplers.append("TOP_A")
+            if do_min_p: enabled_samplers.append("MIN_P")
+            if do_tfss: enabled_samplers.append("TFS")
+            if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF")
+            if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF")
+            if do_typical_ps: enabled_samplers.append("TYPICAL_P")
+            if do_quadratic: enabled_samplers.append("QUADRATIC")
+            if do_xtc: enabled_samplers.append("XTC")
+            if do_nsigmas: enabled_samplers.append("TOP_NSIGMA")
+            if do_dry: enabled_samplers.append("DRY")
+            if do_skew: enabled_samplers.append("SKEW")
+            logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}")
+
+        for sampler_id in sampler_order:
+            if sampler_id == SamplerID.DRY and do_dry:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        f"Applying DRY with dry_multiplier: "
+                        f"{sampling_tensors.dry_multipliers}.")
+                logits = _apply_dry(
+                    logits,
+                    sampling_tensors.prompt_tokens,
+                    sampling_tensors.dry_multipliers,
+                    sampling_tensors.dry_bases, 
+                    sampling_tensors.dry_allowed_lengths,
+                    sampling_tensors.dry_sequence_breaker_ids)
+
+            elif sampler_id == SamplerID.PENALTIES and do_penalties:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying penalties with "
+                        f"pres_pen: {sampling_tensors.presence_penalties}, "
+                        f"freq_pen: {sampling_tensors.frequency_penalties}, "
+                        f"rep_pen: {sampling_tensors.repetition_penalties}.")
+                logits = _apply_penalties(
+                    logits, sampling_tensors.prompt_tokens,
+                    sampling_tensors.output_tokens,
+                    sampling_tensors.presence_penalties,
+                    sampling_tensors.frequency_penalties,
+                    sampling_tensors.repetition_penalties)
+
+            elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \
+                do_no_repeat_ngrams:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying no_repeat_ngram with no_repeat_ngram_size: "
+                        f"{sampling_tensors.no_repeat_ngram_sizes}.")
+                logits = _apply_no_repeat_ngram(
+                    logits,
+                    sampling_tensors.prompt_tokens,
+                    sampling_tensors.no_repeat_ngram_sizes)
+
+            elif sampler_id == SamplerID.TEMPERATURE and do_temperatures:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying temperatures with temperature: "
+                        f"{sampling_tensors.temperatures}, "
+                        f"dynatemp_min: {sampling_tensors.dynatemp_mins}, "
+                        f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, "
+                        f"dynatemp_exp: {sampling_tensors.dynatemp_exps}.")
+                _apply_temperatures(
+                    logits, sampling_tensors.temperatures,
+                    sampling_tensors.dynatemp_mins,
+                    sampling_tensors.dynatemp_maxs,
+                    sampling_tensors.dynatemp_exps)
+
+            elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Top-Nsigma with nsigma: "
+                        f"{sampling_tensors.nsigmas}")
+                logits = _apply_top_nsigma(
+                    logits, sampling_tensors.nsigmas)
+
+            elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \
+                not APHRODITE_USE_SAMPLING_KERNELS:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Top-p and Top-k with top-p: "
+                        f"{sampling_tensors.top_ps}, top_k: "
+                        f"{sampling_tensors.top_ks}.")
+                logits = _apply_top_k_top_p(
+                    logits, sampling_tensors.top_ps,
+                    sampling_tensors.top_ks)
+
+            elif sampler_id == SamplerID.TOP_A and do_top_as:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Top-a with Top-a: "
+                        f"{sampling_tensors.top_as}.")
+                logits = _apply_top_a(
+                    logits, sampling_tensors.top_as)
+
+            elif sampler_id == SamplerID.MIN_P and do_min_p:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Min-p with Min-p: "
+                        f"{sampling_tensors.min_ps}.")
+                logits = _apply_min_p(
+                    logits, sampling_tensors.min_ps)
+
+            elif sampler_id == SamplerID.TFS and do_tfss:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Tail-Free Sampling with tfs: "
+                        f"{sampling_tensors.tfss}.")
+                logits = _apply_tfs(
+                    logits, sampling_tensors.tfss)
+
+            elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying ETA Cutoff with eta_cutoff: "
+                        f"{sampling_tensors.eta_cutoffs}.")
+                logits = _apply_eta_cutoff(
+                    logits, sampling_tensors.eta_cutoffs)
+
+            elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Epsilon Cutoff with epsilon_cutoff: "
+                        f"{sampling_tensors.epsilon_cutoffs}.")
+                logits = _apply_epsilon_cutoff(
+                    logits, sampling_tensors.epsilon_cutoffs)
+
+            elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Locally Typical Sampling with typical_p: "
+                        f"{sampling_tensors.typical_ps}.")
+                logits = _apply_typical_sampling(
+                    logits, sampling_tensors.typical_ps)
+
+            elif sampler_id == SamplerID.QUADRATIC and do_quadratic:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Quadratic and Cubic Sampling with "
+                        "smoothing_factors: "
+                        f"{sampling_tensors.smoothing_factors},"
+                        f" smoothing_curves: "
+                        f"{sampling_tensors.smoothing_curves}.")
+                logits = _apply_quadratic_sampling(
+                    logits, sampling_tensors.smoothing_factors,
+                    sampling_tensors.smoothing_curves)
+
+            elif sampler_id == SamplerID.XTC and do_xtc:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Exclude Top Choices sampling with "
+                        f"xtc_threshold: {sampling_tensors.xtc_thresholds}, "
+                        "xtc_probability: "
+                        f"{sampling_tensors.xtc_probabilities}.")
+                logits = _apply_xtc_sampling(
+                    logits, sampling_tensors.xtc_thresholds,
+                    sampling_tensors.xtc_probabilities)
+
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
 
         # skew needs to be applied post-softmax
         if do_skew:
+            if (sampling_metadata.seq_groups and
+                sampling_metadata.seq_groups[0].is_prompt):
+                logger.debug(
+                    "Applying Skew sampling with skew: "
+                    f"{sampling_tensors.skews}.")
             # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
             cum_probs = torch.cumsum(probs, dim=-1)
             cum_probs = torch.pow(cum_probs, torch.exp(
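
For reference, the order-resolution logic in Sampler.forward above boils down to the following self-contained sketch (ordering only; the actual sampler kernels are elided):

    from enum import IntEnum
    from typing import List, Optional


    class SamplerID(IntEnum):
        # KoboldCpp-compatible values, as in the diff above.
        DRY = 7
        PENALTIES = 6
        NO_REPEAT_NGRAM = 8
        TEMPERATURE = 5
        TOP_NSIGMA = 9
        TOP_P_TOP_K = 0
        TOP_A = 1
        MIN_P = 2
        TFS = 3
        ETA_CUTOFF = 10
        EPSILON_CUTOFF = 11
        TYPICAL_P = 4
        QUADRATIC = 12
        XTC = 13


    DEFAULT_ORDER = [
        SamplerID.DRY, SamplerID.PENALTIES, SamplerID.NO_REPEAT_NGRAM,
        SamplerID.TEMPERATURE, SamplerID.TOP_NSIGMA, SamplerID.TOP_P_TOP_K,
        SamplerID.TOP_A, SamplerID.MIN_P, SamplerID.TFS, SamplerID.ETA_CUTOFF,
        SamplerID.EPSILON_CUTOFF, SamplerID.TYPICAL_P, SamplerID.QUADRATIC,
        SamplerID.XTC,
    ]


    def resolve_order(priority: Optional[List[int]],
                      temp_last: bool) -> List[SamplerID]:
        """A custom priority wins outright (temperature_last is ignored);
        otherwise temperature_last moves TEMPERATURE to just after XTC."""
        if priority:
            return [SamplerID(x) for x in priority]
        order = []
        for sid in DEFAULT_ORDER:
            if sid == SamplerID.TEMPERATURE and temp_last:
                continue
            order.append(sid)
            if sid == SamplerID.XTC and temp_last:
                order.append(SamplerID.TEMPERATURE)
        return order


    # With temperature_last, TEMPERATURE ends up after XTC, i.e. last.
    assert resolve_order(None, temp_last=True)[-1] == SamplerID.TEMPERATURE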