feat: add sampler_priority (#837)

* feat: add sampler_priority

* fix: sampler arg verification

* more clean-up and remove min_tokens from the order

* more cleaning up and logs

* alias sampler_priority to sampler_order
AlpinDale committed 4 months ago · parent commit dfa34d1b24

aphrodite/common/sampling_params.py (+44 -0)

@@ -23,6 +23,25 @@ class SamplingType(IntEnum):
     RANDOM_SEED = 2
     BEAM = 3
 
 
+class SamplerID(IntEnum):
+    # Mirror these in aphrodite/modeling/layers/sampler.py
+    # Values out of order to keep backwards compatibility
+    # with Koboldcpp values
+    DRY = 7
+    PENALTIES = 6
+    NO_REPEAT_NGRAM = 8
+    TEMPERATURE = 5
+    TOP_NSIGMA = 9
+    TOP_P_TOP_K = 0
+    TOP_A = 1
+    MIN_P = 2
+    TFS = 3
+    ETA_CUTOFF = 10
+    EPSILON_CUTOFF = 11
+    TYPICAL_P = 4
+    QUADRATIC = 12
+    XTC = 13
+
 
 
 LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
                             Callable[[List[int], List[int], torch.Tensor],
@@ -175,6 +194,8 @@ class SamplingParams(
             Defaults to None.
         skew: Bias the token selection towards higher or lower probability
             tokens. Defaults to 0 (disabled).
+        sampler_priority: A list of integers to control the order in which
+            samplers are applied.
     """
 
     n: int = 1
@@ -227,6 +248,7 @@ class SamplingParams(
     dry_allowed_length: int = 2
     dry_sequence_breaker_ids: List[int] = []
     skew: float = 0.0
+    sampler_priority: Optional[List[int]] = []
     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
     output_text_buffer_length: int = 0
@@ -279,6 +301,7 @@ class SamplingParams(
         "dry_allowed_length": 2,
         "dry_allowed_length": 2,
         "dry_sequence_breaker_ids": [],
         "dry_sequence_breaker_ids": [],
         "skew": 0.0,
         "skew": 0.0,
+        "sampler_priority": [],
     }
     }
 
 
     def __post_init__(self) -> None:
     def __post_init__(self) -> None:
@@ -428,6 +451,27 @@ class SamplingParams(
             raise ValueError(
                 "skew must be non-negative, got "
                 f"{self.skew}.")
+
+        if self.sampler_priority is not None:
+            if not self.sampler_priority:
+                self.sampler_priority = None
+                return
+
+            if not isinstance(self.sampler_priority, list):
+                raise ValueError("sampler_priority must be a list of integers")
+            try:
+                provided_samplers = {
+                    SamplerID(x) for x in self.sampler_priority}
+            except ValueError as e:
+                raise ValueError(
+                    f"Invalid sampler ID in priority list: {e}") from e
+
+            required_samplers = set(SamplerID)
+            if not required_samplers.issubset(provided_samplers):
+                missing = required_samplers - provided_samplers
+                missing_names = [s.name for s in missing]
+                raise ValueError(f"Missing required samplers in priority list: "
+                                 f"{missing_names}")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
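A quick usage sketch (illustrative, not part of the diff): the validator above requires the priority list to cover every SamplerID, so a custom order is easiest to build by permuting the default chain. The SamplingParams fields used here are standard; the script itself is hypothetical.

# Hypothetical usage sketch: move TEMPERATURE to the end of the chain.
# The _verify_args check above rejects lists that omit any SamplerID.
from aphrodite.common.sampling_params import SamplerID, SamplingParams

default_order = [
    SamplerID.DRY, SamplerID.PENALTIES, SamplerID.NO_REPEAT_NGRAM,
    SamplerID.TEMPERATURE, SamplerID.TOP_NSIGMA, SamplerID.TOP_P_TOP_K,
    SamplerID.TOP_A, SamplerID.MIN_P, SamplerID.TFS, SamplerID.ETA_CUTOFF,
    SamplerID.EPSILON_CUTOFF, SamplerID.TYPICAL_P, SamplerID.QUADRATIC,
    SamplerID.XTC,
]
order = [s for s in default_order if s is not SamplerID.TEMPERATURE]
order.append(SamplerID.TEMPERATURE)

params = SamplingParams(
    temperature=0.8,
    top_p=0.9,
    sampler_priority=[int(s) for s in order],
)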

aphrodite/endpoints/openai/protocol.py (+12 -1)

@@ -5,7 +5,8 @@ import time
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
+                      model_validator)
 from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 
@@ -160,6 +161,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
     nsigma: Optional[float] = 0.0
     skew: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
+    sampler_priority: Optional[List[int]] = Field(
+        default=[],
+        validation_alias=AliasChoices("sampler_priority",
+                                      "sampler_order"))
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
@@ -317,6 +322,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             nsigma=self.nsigma,
             skew=self.skew,
             custom_token_bans=self.custom_token_bans,
+            sampler_priority=self.sampler_priority,
         )
 
     @model_validator(mode='before')
@@ -436,6 +442,10 @@ class CompletionRequest(OpenAIBaseModel):
     nsigma: Optional[float] = 0.0
     skew: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
+    sampler_priority: Optional[List[int]] = Field(
+        default=[],
+        validation_alias=AliasChoices("sampler_priority",
+                                      "sampler_order"))
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -552,6 +562,7 @@ class CompletionRequest(OpenAIBaseModel):
             nsigma=self.nsigma,
             skew=self.skew,
             custom_token_bans=self.custom_token_bans,
+            sampler_priority=self.sampler_priority,
         )
 
     @model_validator(mode="before")
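Because of the AliasChoices above, clients can keep sending the Koboldcpp-style key "sampler_order" instead of "sampler_priority". A hypothetical request against a local Aphrodite OpenAI-compatible server (URL, port, and model name are placeholders):

# Hypothetical client sketch: "sampler_order" is accepted as an alias.
# The list must cover all 14 SamplerID values; this one is the default
# order with TEMPERATURE (5) moved to the end.
import requests

payload = {
    "model": "placeholder-model",
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "sampler_order": [7, 6, 8, 9, 0, 1, 2, 3, 10, 11, 4, 12, 13, 5],
}
resp = requests.post("http://localhost:2242/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])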

aphrodite/modeling/layers/sampler.py (+249 -75)

@@ -2,11 +2,13 @@
 import itertools
 import os
 import warnings
+from enum import IntEnum
 from math import inf
 from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from loguru import logger
 
 import aphrodite._custom_ops as ops
 from aphrodite.common.sampling_params import SamplingType
@@ -36,6 +38,26 @@ APHRODITE_USE_SAMPLING_KERNELS = bool(int(
     os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
 
 
+class SamplerID(IntEnum):
+    # Mirror these in aphrodite/common/sampling_params.py
+    # Values out of order to keep backwards compatibility
+    # with Koboldcpp values
+    DRY = 7
+    PENALTIES = 6
+    NO_REPEAT_NGRAM = 8
+    TEMPERATURE = 5
+    TOP_NSIGMA = 9
+    TOP_P_TOP_K = 0
+    TOP_A = 1
+    MIN_P = 2
+    TFS = 3
+    ETA_CUTOFF = 10
+    EPSILON_CUTOFF = 11
+    TYPICAL_P = 4
+    QUADRATIC = 12
+    XTC = 13
+
+
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
 
@@ -151,90 +173,242 @@ class Sampler(nn.Module):
         do_temp_last = self._do_temp_last
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
-
-        if do_dry:
-            logits = _apply_dry(
-                logits,
-                sampling_tensors.prompt_tokens,
-                sampling_tensors.dry_multipliers,
-                sampling_tensors.dry_bases, 
-                sampling_tensors.dry_allowed_lengths,
-                sampling_tensors.dry_sequence_breaker_ids
-            )
-
-        # Apply presence and frequency penalties.
-        if do_penalties:
-            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
-                                      sampling_tensors.output_tokens,
-                                      sampling_tensors.presence_penalties,
-                                      sampling_tensors.frequency_penalties,
-                                      sampling_tensors.repetition_penalties)
-        
-        if do_no_repeat_ngrams:
-            logits = _apply_no_repeat_ngram(
-                logits,
-                sampling_tensors.prompt_tokens,
-                sampling_tensors.no_repeat_ngram_sizes)
-
-        # Apply temperature scaling if not doing temp_last.
-        if do_temperatures and not do_temp_last:
-            _apply_temperatures(logits, sampling_tensors.temperatures,
-                                sampling_tensors.dynatemp_mins,
-                                sampling_tensors.dynatemp_maxs,
-                                sampling_tensors.dynatemp_exps)
-
-        if do_nsigmas:
-            logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
-
-        if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
-
-        if do_top_as:
-            logits = _apply_top_a(logits, sampling_tensors.top_as)
-
-        if do_min_p:
-            logits = _apply_min_p(logits, sampling_tensors.min_ps)
-
-        if do_tfss:
-            logits = _apply_tfs(logits, sampling_tensors.tfss)
-
-        if do_eta_cutoffs:
-            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
-
-        if do_epsilon_cutoffs:
-            logits = _apply_epsilon_cutoff(logits,
-                                           sampling_tensors.epsilon_cutoffs)
-
-        if do_typical_ps:
-            logits = _apply_typical_sampling(logits,
-                                             sampling_tensors.typical_ps)
-
-        if do_quadratic:
-            logits = _apply_quadratic_sampling(
-                logits, sampling_tensors.smoothing_factors,
-                sampling_tensors.smoothing_curves)
-
-        if do_xtc:
-            logits = _apply_xtc_sampling(
-                logits, sampling_tensors.xtc_thresholds,
-                sampling_tensors.xtc_probabilities)
-
-        if do_temperatures and do_temp_last:
-            _apply_temperatures(logits, sampling_tensors.temperatures,
-                                sampling_tensors.dynatemp_mins,
-                                sampling_tensors.dynatemp_maxs,
-                                sampling_tensors.dynatemp_exps)
-
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         logits = _apply_token_bans(logits, banned_tokens)
 
+        sampler_order = None
+        if sampling_metadata.seq_groups:
+            sampler_order = sampling_metadata.seq_groups[
+                0].sampling_params.sampler_priority
+
+            # Warn if both custom order and temp_last are specified
+            if sampler_order is not None and do_temp_last:
+                logger.warning(
+                    "Both sampler_priority and temperature_last=True "
+                    "were specified. Using custom sampler_priority order "
+                    "and ignoring temperature_last.")
+
+        if sampler_order is None:
+            default_order = [
+                SamplerID.DRY,
+                SamplerID.PENALTIES,
+                SamplerID.NO_REPEAT_NGRAM,
+                SamplerID.TEMPERATURE,
+                SamplerID.TOP_NSIGMA,
+                SamplerID.TOP_P_TOP_K,
+                SamplerID.TOP_A,
+                SamplerID.MIN_P,
+                SamplerID.TFS,
+                SamplerID.ETA_CUTOFF,
+                SamplerID.EPSILON_CUTOFF,
+                SamplerID.TYPICAL_P,
+                SamplerID.QUADRATIC,
+                SamplerID.XTC,
+            ]
+
+            sampler_order = []
+            for sampler_id in default_order:
+                if sampler_id == SamplerID.TEMPERATURE and do_temp_last:
+                    continue
+                sampler_order.append(sampler_id)
+
+                if sampler_id == SamplerID.XTC and do_temp_last:
+                    sampler_order.append(SamplerID.TEMPERATURE)
+
+        if sampling_metadata.seq_groups and sampling_metadata.seq_groups[
+            0].is_prompt:
+            logger.debug("Sampler execution order: ")
+            for i, sampler_id in enumerate(sampler_order, 1):
+                logger.debug(f"{i}. {SamplerID(sampler_id).name}")
+
+            enabled_samplers = []
+            # ruff: noqa: E701
+            if do_penalties: enabled_samplers.append("PENALTIES")
+            if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM")
+            if do_temperatures: enabled_samplers.append("TEMPERATURE")
+            if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K")
+            if do_top_as: enabled_samplers.append("TOP_A")
+            if do_min_p: enabled_samplers.append("MIN_P")
+            if do_tfss: enabled_samplers.append("TFS")
+            if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF")
+            if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF")
+            if do_typical_ps: enabled_samplers.append("TYPICAL_P")
+            if do_quadratic: enabled_samplers.append("QUADRATIC")
+            if do_xtc: enabled_samplers.append("XTC")
+            if do_nsigmas: enabled_samplers.append("TOP_NSIGMA")
+            if do_dry: enabled_samplers.append("DRY")
+            if do_skew: enabled_samplers.append("SKEW")
+            logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}")
+
+        for sampler_id in sampler_order:
+            if sampler_id == SamplerID.DRY and do_dry:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        f"Applying DRY with dry_multiplier: "
+                        f"{sampling_tensors.dry_multipliers}.")
+                logits = _apply_dry(
+                    logits,
+                    sampling_tensors.prompt_tokens,
+                    sampling_tensors.dry_multipliers,
+                    sampling_tensors.dry_bases,
+                    sampling_tensors.dry_allowed_lengths,
+                    sampling_tensors.dry_sequence_breaker_ids)
+
+            elif sampler_id == SamplerID.PENALTIES and do_penalties:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying penalties with "
+                        f"pres_pen: {sampling_tensors.presence_penalties}, "
+                        f"freq_pen: {sampling_tensors.frequency_penalties}, "
+                        f"rep_pen: {sampling_tensors.repetition_penalties}.")
+                logits = _apply_penalties(
+                    logits, sampling_tensors.prompt_tokens,
+                    sampling_tensors.output_tokens,
+                    sampling_tensors.presence_penalties,
+                    sampling_tensors.frequency_penalties,
+                    sampling_tensors.repetition_penalties)
+
+            elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \
+                do_no_repeat_ngrams:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying no_repeat_ngram with no_repeat_ngram_size: "
+                        f"{sampling_tensors.no_repeat_ngram_sizes}.")
+                logits = _apply_no_repeat_ngram(
+                    logits,
+                    sampling_tensors.prompt_tokens,
+                    sampling_tensors.no_repeat_ngram_sizes)
+
+            elif sampler_id == SamplerID.TEMPERATURE and do_temperatures:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying temperatures with temperature: "
+                        f"{sampling_tensors.temperatures}, "
+                        f"dynatemp_min: {sampling_tensors.dynatemp_mins}, "
+                        f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, "
+                        f"dynatemp_exp: {sampling_tensors.dynatemp_exps}.")
+                _apply_temperatures(
+                    logits, sampling_tensors.temperatures,
+                    sampling_tensors.dynatemp_mins,
+                    sampling_tensors.dynatemp_maxs,
+                    sampling_tensors.dynatemp_exps)
+
+            elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Top-Nsigma with nsigma: "
+                        f"{sampling_tensors.nsigmas}")
+                logits = _apply_top_nsigma(
+                    logits, sampling_tensors.nsigmas)
+
+            elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \
+                not APHRODITE_USE_SAMPLING_KERNELS:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Top-p and Top-k with top-p: "
+                        f"{sampling_tensors.top_ps}, top_k: "
+                        f"{sampling_tensors.top_ks}.")
+                logits = _apply_top_k_top_p(
+                    logits, sampling_tensors.top_ps,
+                    sampling_tensors.top_ks)
+
+            elif sampler_id == SamplerID.TOP_A and do_top_as:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Top-a with Top-a: "
+                        f"{sampling_tensors.top_as}.")
+                logits = _apply_top_a(
+                    logits, sampling_tensors.top_as)
+
+            elif sampler_id == SamplerID.MIN_P and do_min_p:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Min-p with Min-p: "
+                        f"{sampling_tensors.min_ps}.")
+                logits = _apply_min_p(
+                    logits, sampling_tensors.min_ps)
+
+            elif sampler_id == SamplerID.TFS and do_tfss:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Tail-Free Sampling with tfs: "
+                        f"{sampling_tensors.tfss}.")
+                logits = _apply_tfs(
+                    logits, sampling_tensors.tfss)
+
+            elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying ETA Cutoff with eta_cutoff: "
+                        f"{sampling_tensors.eta_cutoffs}.")
+                logits = _apply_eta_cutoff(
+                    logits, sampling_tensors.eta_cutoffs)
+
+            elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Epsilon Cutoff with epsilon_cutoff: "
+                        f"{sampling_tensors.epsilon_cutoffs}.")
+                logits = _apply_epsilon_cutoff(
+                    logits, sampling_tensors.epsilon_cutoffs)
+
+            elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Locally Typical Sampling with typical_p: "
+                        f"{sampling_tensors.typical_ps}.")
+                logits = _apply_typical_sampling(
+                    logits, sampling_tensors.typical_ps)
+
+            elif sampler_id == SamplerID.QUADRATIC and do_quadratic:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Quadratic and Cubic Sampling with "
+                        "smoothing_factors: "
+                        f"{sampling_tensors.smoothing_factors},"
+                        f" smoothing_curves: "
+                        f"{sampling_tensors.smoothing_curves}.")
+                logits = _apply_quadratic_sampling(
+                    logits, sampling_tensors.smoothing_factors,
+                    sampling_tensors.smoothing_curves)
+
+            elif sampler_id == SamplerID.XTC and do_xtc:
+                if (sampling_metadata.seq_groups and
+                    sampling_metadata.seq_groups[0].is_prompt):
+                    logger.debug(
+                        "Applying Exclude Top Choices sampling with "
+                        f"xtc_threshold: {sampling_tensors.xtc_thresholds}, "
+                        "xtc_probability: "
+                        f"{sampling_tensors.xtc_probabilities}.")
+                logits = _apply_xtc_sampling(
+                    logits, sampling_tensors.xtc_thresholds,
+                    sampling_tensors.xtc_probabilities)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
 
         # skew needs to be applied post-softmax
         if do_skew:
+            if (sampling_metadata.seq_groups and
+                sampling_metadata.seq_groups[0].is_prompt):
+                logger.debug(
+                    "Applying Skew sampling with skew: "
+                    f"{sampling_tensors.skews}.")
             # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
             cum_probs = torch.cumsum(probs, dim=-1)
             cum_probs = torch.pow(cum_probs, torch.exp(
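The skew step above warps the post-softmax cumulative distribution. For reference, a standalone sketch of the transform (after the exllamav2 commit linked above; apply_skew is a hypothetical helper, and the (num_seqs, vocab_size) tensor shape is assumed):

import torch

def apply_skew(probs: torch.Tensor, skews: torch.Tensor) -> torch.Tensor:
    # Raise the CDF to the power exp(skew); exp(0) == 1 is the identity,
    # and the final CDF entry stays 1, so the result still sums to 1.
    cum_probs = torch.cumsum(probs, dim=-1)
    cum_probs = torch.pow(cum_probs, torch.exp(skews).unsqueeze(-1))
    # Difference the warped CDF to recover per-token probabilities.
    return torch.diff(cum_probs, dim=-1,
                      prepend=torch.zeros_like(cum_probs[..., :1]))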