
feat: implement top-nsigma sampling method

AlpinDale committed 4 months ago
Parent commit: 22429d07b3
4 changed files with 60 additions and 7 deletions
  1. aphrodite/common/sampling_params.py (+12 -0)
  2. aphrodite/endpoints/openai/protocol.py (+4 -0)
  3. aphrodite/modeling/layers/sampler.py (+28 -1)
  4. aphrodite/modeling/sampling_metadata.py (+16 -6)

+ 12 - 0
aphrodite/common/sampling_params.py

@@ -148,6 +148,11 @@ class SamplingParams(
             above this threshold, consider removing all but the last one.
         xtc_probability: Probability that the removal will actually happen.
             0 disables the sampler, 1 makes it always happen.
+        nsigma: Number of standard deviations from the maximum logit to use
+            as a cutoff threshold. Tokens with logits below
+            (max_logit - nsigma * std_dev) are filtered out. Higher values
+            (e.g. 3.0) keep more tokens, lower values (e.g. 1.0) are more
+            selective. Must be non-negative; 0 disables the filter.
     """
 
     n: int = 1
@@ -193,6 +198,7 @@ class SamplingParams(
     truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
     xtc_threshold: float = 0.1
     xtc_probability: float = 0
+    nsigma: float = 0.0
 
     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
@@ -239,6 +245,7 @@ class SamplingParams(
         "truncate_prompt_tokens": None,
         "xtc_threshold": 0.1,
         "xtc_probability": 0,
+        "nsigma": 0.0,
     }
 
     def __post_init__(self) -> None:
@@ -368,6 +375,10 @@ class SamplingParams(
             raise ValueError(
                 "xtc_probability must be in [0, 1], got "
                 f"{self.xtc_probability}.")
+        if self.nsigma < 0.0:
+            raise ValueError(
+                "nsigma must be non-negative, got "
+                f"{self.nsigma}.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
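
For reference, the new knob is set like any other field on SamplingParams. A minimal sketch (field names taken from the diff above; the import path is assumed to match the repository layout):

    from aphrodite.common.sampling_params import SamplingParams

    # nsigma=1.5 keeps only tokens whose logit lies within 1.5 standard
    # deviations of the maximum logit; 0.0 (the default) disables the filter.
    params = SamplingParams(temperature=0.8, nsigma=1.5)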

+ 4 - 0
aphrodite/endpoints/openai/protocol.py

@@ -150,6 +150,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     dynatemp_min: Optional[float] = 0.0
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
+    nsigma: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
     # doc: end-chat-completion-sampling-params
 
@@ -293,6 +294,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             dynatemp_min=self.dynatemp_min,
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,
+            nsigma=self.nsigma,
             custom_token_bans=self.custom_token_bans,
         )
 
@@ -404,6 +406,7 @@ class CompletionRequest(OpenAIBaseModel):
     dynatemp_min: Optional[float] = 0.0
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
+    nsigma: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
     # doc: end-completion-sampling-params
 
@@ -506,6 +509,7 @@ class CompletionRequest(OpenAIBaseModel):
             dynatemp_min=self.dynatemp_min,
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,
+            nsigma=self.nsigma,
             custom_token_bans=self.custom_token_bans,
         )
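
With both request models extended, the field passes straight through the OpenAI-compatible API. A hedged example request (server address and model name are placeholders, not values from this commit):

    import requests

    resp = requests.post(
        "http://localhost:2242/v1/completions",  # placeholder address
        json={
            "model": "my-model",                  # placeholder model name
            "prompt": "The capital of France is",
            "max_tokens": 16,
            "nsigma": 1.5,                        # new field from this commit
        },
    )
    print(resp.json()["choices"][0]["text"])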
 

+ 28 - 1
aphrodite/modeling/layers/sampler.py

@@ -77,7 +77,7 @@ class Sampler(nn.Module):
         # Initialize new sampling tensors
         (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
          do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-         do_typical_ps, do_quadratic, do_xtc, do_temp_last
+         do_typical_ps, do_quadratic, do_xtc, do_nsigmas, do_temp_last
          ) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
@@ -93,6 +93,7 @@ class Sampler(nn.Module):
         self._do_typical_ps = do_typical_ps
         self._do_quadratic = do_quadratic
         self._do_xtc = do_xtc
+        self._do_nsigmas = do_nsigmas
         self._do_temp_last = do_temp_last
 
     def forward(
@@ -131,6 +132,7 @@ class Sampler(nn.Module):
         do_typical_ps = self._do_typical_ps
         do_quadratic = self._do_quadratic
         do_xtc = self._do_xtc
+        do_nsigmas = self._do_nsigmas
         do_temp_last = self._do_temp_last
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
@@ -150,6 +152,9 @@ class Sampler(nn.Module):
                                 sampling_tensors.dynatemp_maxs,
                                 sampling_tensors.dynatemp_exps)
 
+        if do_nsigmas:
+            logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
+
         if do_top_p_top_k:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
@@ -634,6 +639,28 @@ def _apply_xtc_sampling(
     return logits
 
 
+def _apply_top_nsigma(
+        logits: torch.Tensor,
+        nsigma: torch.Tensor,
+) -> torch.Tensor:
+    """Apply top-nsigma truncation to the logits.
+
+    Reference: https://arxiv.org/abs/2411.07641
+
+    Args:
+        logits: Logits of shape (num_tokens, vocab_size)
+        nsigma: Number of standard deviations to use as threshold
+    Returns:
+        Modified logits with values below threshold set to -inf
+    """
+    std = logits.std(dim=-1, keepdim=True)
+    threshold = (logits.max(dim=-1, keepdim=True).values -
+                 nsigma.unsqueeze(dim=1) * std)
+    logits[logits < threshold] = float("-inf")
+
+    return logits
+
+
 def _greedy_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     samples: torch.Tensor,
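
The cutoff is easy to sanity-check on a toy batch. A standalone reproduction of what _apply_top_nsigma computes (illustrative values only; uses masked_fill instead of the in-place indexing above):

    import torch

    logits = torch.tensor([[5.0, 4.0, 1.0, 0.5, 0.0]])
    nsigma = torch.tensor([1.0])

    # The unbiased std is ~2.25, so the threshold is 5.0 - 1.0 * 2.25 = ~2.75
    # and only the two largest logits survive.
    std = logits.std(dim=-1, keepdim=True)
    threshold = logits.max(dim=-1, keepdim=True).values - nsigma.unsqueeze(1) * std
    print(logits.masked_fill(logits < threshold, float("-inf")))
    # tensor([[5., 4., -inf, -inf, -inf]])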

+ 16 - 6
aphrodite/modeling/sampling_metadata.py

@@ -387,6 +387,7 @@ class SamplingTensors:
     smoothing_curves: torch.Tensor
     xtc_thresholds: torch.Tensor
     xtc_probabilities: torch.Tensor
+    nsigmas: torch.Tensor
     sampling_seeds: torch.Tensor
     sample_indices: torch.Tensor
     extra_seeds: Optional[torch.Tensor]
@@ -404,7 +405,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -432,6 +433,7 @@ class SamplingTensors:
         smoothing_curves: List[float] = []
         xtc_thresholds: List[float] = []
         xtc_probabilities: List[float] = []
+        nsigmas: List[float] = []
         sampling_seeds: List[List[int]] = []
         sample_indices: List[int] = []
         do_penalties = False
@@ -445,6 +447,7 @@ class SamplingTensors:
         do_typical_ps = False
         do_quadratic = False
         do_xtc = False
+        do_nsigmas = False
         do_temp_last = False
 
         if _USE_TRITON_SAMPLER:
@@ -487,6 +490,7 @@ class SamplingTensors:
             do_quadratic |= (params.smoothing_factor > _SAMPLING_EPS or
                              params.smoothing_curve > 1.0)
             do_xtc |= params.xtc_probability > _SAMPLING_EPS
+            do_nsigmas |= params.nsigma > _SAMPLING_EPS
             do_temp_last |= params.temperature_last
 
             is_prompt = seq_group.is_prompt
@@ -521,6 +525,7 @@ class SamplingTensors:
             smoothing_curves += [params.smoothing_curve] * n_seqs
             xtc_thresholds += [params.xtc_threshold] * n_seqs
             xtc_probabilities += [params.xtc_probability] * n_seqs
+            nsigmas += [params.nsigma] * n_seqs
 
             if _USE_TRITON_SAMPLER:
                 if is_prompt:
@@ -567,13 +572,13 @@ class SamplingTensors:
             temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
             tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, xtc_thresholds, xtc_probabilities, sampling_seeds,
-            sample_indices, prompt_tokens, output_tokens, vocab_size,
-            extra_seeds_to_generate, device, dtype)
+            smoothing_curves, xtc_thresholds, xtc_probabilities, nsigmas,
+            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
+            vocab_size, extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_temperatures,
                 do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
                 do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
-                do_temp_last)
+                do_nsigmas, do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
@@ -586,7 +591,7 @@ class SamplingTensors:
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], xtc_thresholds: List[float],
-                   xtc_probabilities: List[float],
+                   xtc_probabilities: List[float], nsigmas: List[float],
                    sampling_seeds: List[List[int]],
                    sample_indices: List[int], prompt_tokens: List[array],
                    output_tokens: List[array], vocab_size: int,
@@ -719,6 +724,10 @@ class SamplingTensors:
                                            device="cpu",
                                            dtype=dtype,
                                            pin_memory=pin_memory)
+        nsigmas_t = torch.tensor(nsigmas,
+                                 device="cpu",
+                                 dtype=dtype,
+                                 pin_memory=pin_memory)
         sample_indices_t = torch.tensor(
             sample_indices,
             device="cpu",
@@ -775,6 +784,7 @@ class SamplingTensors:
                                                non_blocking=True),
             xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                      non_blocking=True),
+            nsigmas=nsigmas_t.to(device=device, non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),
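
The nsigmas tensor follows the same staging pattern as the other sampling tensors: built per-sequence on (optionally pinned) CPU memory, then copied to the device with non_blocking=True. A small sketch of that pattern in isolation (values are illustrative):

    import torch

    pin = torch.cuda.is_available()  # pinning only pays off before a GPU copy
    nsigmas_t = torch.tensor([1.5, 1.5, 0.0],  # one entry per sequence
                             device="cpu",
                             dtype=torch.float32,
                             pin_memory=pin)
    if torch.cuda.is_available():
        nsigmas_t = nsigmas_t.to(device="cuda", non_blocking=True)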