
feat: add skew sampling (#834)

* feat: add skew sampling

* add tests

* modify based on comments from turboderp

* limit skew to positive values only
AlpinDale · 3 months ago · commit 60f7b828d5

+ 9 - 1
aphrodite/common/sampling_params.py

@@ -173,6 +173,8 @@ class SamplingParams(
             input into sections where repetition is evaluated separately.
             Common examples are newlines, quotes, and other structural tokens.
             Defaults to None.
+        skew: Bias the token selection towards higher or lower probability
+            tokens. Defaults to 0 (disabled).
     """
 
     n: int = 1
@@ -224,6 +226,7 @@ class SamplingParams(
     dry_base: float = 1.75
     dry_allowed_length: int = 2
     dry_sequence_breaker_ids: List[int] = []
+    skew: float = 0.0
     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
     output_text_buffer_length: int = 0
@@ -275,6 +278,7 @@ class SamplingParams(
         "dry_base": 1.75,
         "dry_allowed_length": 2,
         "dry_sequence_breaker_ids": [],
+        "skew": 0.0,
     }
 
     def __post_init__(self) -> None:
@@ -419,7 +423,11 @@ class SamplingParams(
         if self.dry_allowed_length < 0:
             raise ValueError(
                 "dry_allowed_length must be non-negative, got "
-                f"{self.dry_allowed_length}.")    
+                f"{self.dry_allowed_length}.")
+        if self.skew < 0.0:
+            raise ValueError(
+                "skew must be non-negative, got "
+                f"{self.skew}.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
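
The new field is set like any other sampling parameter. Below is a minimal sketch, assuming SamplingParams is constructed directly from the module touched above and that the new check runs on construction the same way the existing _verify_args checks do; model loading and generation are omitted.

# Minimal sketch: the new `skew` field on SamplingParams. The import path
# matches the file modified above; the values are illustrative.
from aphrodite.common.sampling_params import SamplingParams

params = SamplingParams(temperature=1.0, skew=1.0)  # bias applied post-softmax
neutral = SamplingParams()                          # skew defaults to 0.0 (disabled)

try:
    SamplingParams(skew=-0.5)  # rejected: skew is limited to non-negative values
except ValueError as err:
    print(err)                 # "skew must be non-negative, got -0.5."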

+ 4 - 0
aphrodite/endpoints/openai/protocol.py

@@ -158,6 +158,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
     nsigma: Optional[float] = 0.0
+    skew: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
     # doc: end-chat-completion-sampling-params
 
@@ -314,6 +315,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,
             nsigma=self.nsigma,
+            skew=self.skew,
             custom_token_bans=self.custom_token_bans,
         )
 
@@ -432,6 +434,7 @@ class CompletionRequest(OpenAIBaseModel):
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
     nsigma: Optional[float] = 0.0
+    skew: Optional[float] = 0.0
     custom_token_bans: Optional[List[int]] = None
     # doc: end-completion-sampling-params
 
@@ -547,6 +550,7 @@ class CompletionRequest(OpenAIBaseModel):
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,
             nsigma=self.nsigma,
+            skew=self.skew,
             custom_token_bans=self.custom_token_bans,
         )
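
Because skew is added to both ChatCompletionRequest and CompletionRequest and forwarded through to_sampling_params, it can be passed as an extra field in an ordinary OpenAI-style request body. A minimal sketch follows; the host, port, and model name are placeholders for a locally running server, not values taken from this commit.

# Minimal sketch: sending `skew` to the OpenAI-compatible completions endpoint.
import requests

resp = requests.post(
    "http://localhost:2242/v1/completions",  # placeholder host/port
    json={
        "model": "placeholder-model",
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "temperature": 1.0,
        "skew": 1.0,  # forwarded to SamplingParams.skew by to_sampling_params
    },
)
print(resp.json()["choices"][0]["text"])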
 

+ 14 - 1
aphrodite/modeling/layers/sampler.py

@@ -86,7 +86,7 @@ class Sampler(nn.Module):
         (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
          do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
          do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
-         do_dry, do_temp_last
+         do_dry, do_skew, do_temp_last
          ) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
@@ -105,6 +105,7 @@ class Sampler(nn.Module):
         self._do_xtc = do_xtc
         self._do_nsgimas = do_nsigmas
         self._do_dry = do_dry
+        self._do_skew = do_skew
         self._do_temp_last = do_temp_last
 
     def forward(
@@ -146,6 +147,7 @@ class Sampler(nn.Module):
         do_xtc = self._do_xtc
         do_nsigmas = self._do_nsgimas
         do_dry = self._do_dry
+        do_skew = self._do_skew
         do_temp_last = self._do_temp_last
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
@@ -230,6 +232,17 @@ class Sampler(nn.Module):
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+
+        # skew needs to be applied post-softmax
+        if do_skew:
+            # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
+            cum_probs = torch.cumsum(probs, dim=-1)
+            cum_probs = torch.pow(cum_probs, torch.exp(
+                sampling_tensors.skews).unsqueeze(dim=1))
+            probs = torch.diff(cum_probs, dim=-1,
+                               prepend=torch.zeros_like(cum_probs[..., :1]))
+            logits = torch.log(probs)
+
         # Compute the log probabilities.
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
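
The transform itself is small enough to try in isolation: raise the running CDF to the power exp(skew), then take the discrete difference to recover a still-normalized probability vector. The sketch below uses a hand-made four-token distribution, ordered by probability purely for readability, to show how mass moves away from the top token when skew is positive.

# Minimal sketch of the post-softmax skew transform used in the sampler above.
# The toy distribution is illustrative, not taken from a real model.
import torch

probs = torch.tensor([[0.70, 0.20, 0.07, 0.03]])  # one sequence, 4-token vocab
skews = torch.tensor([1.0])                       # per-sequence skew values

cum_probs = torch.cumsum(probs, dim=-1)                              # running CDF
cum_probs = torch.pow(cum_probs, torch.exp(skews).unsqueeze(dim=1))  # CDF ** e^skew
skewed = torch.diff(cum_probs, dim=-1,
                    prepend=torch.zeros_like(cum_probs[..., :1]))

print(skewed)           # roughly [[0.38, 0.37, 0.17, 0.08]]: the top token loses mass
print(skewed.sum(-1))   # ~1.0, since the final CDF entry is exactly 1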
 

+ 19 - 6
aphrodite/modeling/sampling_metadata.py

@@ -393,6 +393,7 @@ class SamplingTensors:
     dry_bases: torch.Tensor
     dry_allowed_lengths: torch.Tensor
     dry_sequence_breaker_ids: torch.Tensor
+    skews: torch.Tensor
     sampling_seeds: torch.Tensor
     sample_indices: torch.Tensor
     extra_seeds: Optional[torch.Tensor]
@@ -410,7 +411,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -446,6 +447,7 @@ class SamplingTensors:
         dry_bases: List[float] = []
         dry_allowed_lengths: List[int] = []
         dry_sequence_breaker_ids: List[List[int]] = []
+        skews: List[float] = []
 
         do_penalties = False
         do_no_repeat_ngrams = False
@@ -461,6 +463,7 @@ class SamplingTensors:
         do_xtc = False
         do_nsigmas = False
         do_dry = False
+        do_skews = False
         do_temp_last = False
 
         if _USE_TRITON_SAMPLER:
@@ -506,6 +509,7 @@ class SamplingTensors:
             do_xtc |= params.xtc_probability > _SAMPLING_EPS
             do_nsigmas |= params.nsigma > _SAMPLING_EPS
             do_dry |= params.dry_multiplier > _SAMPLING_EPS
+            do_skews |= abs(params.skew) > _SAMPLING_EPS
 
             do_temp_last |= params.temperature_last
 
@@ -548,6 +552,7 @@ class SamplingTensors:
             dry_allowed_lengths += [params.dry_allowed_length] * n_seqs
             dry_sequence_breaker_ids += (
                 [params.dry_sequence_breaker_ids] * n_seqs)
+            skews += [params.skew] * n_seqs
 
             if _USE_TRITON_SAMPLER:
                 if is_prompt:
@@ -596,13 +601,14 @@ class SamplingTensors:
             no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
             typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
             xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
-            dry_allowed_lengths, dry_sequence_breaker_ids, sampling_seeds,
-            sample_indices, prompt_tokens, output_tokens, vocab_size,
-            extra_seeds_to_generate, device, dtype)
+            dry_allowed_lengths, dry_sequence_breaker_ids, skews,
+            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
+            vocab_size, extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
                 do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
                 do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
-                do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last)
+                do_quadratic, do_xtc, do_nsigmas, do_dry, do_skews,
+                do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
@@ -620,7 +626,7 @@ class SamplingTensors:
                    dry_multipliers: List[float], dry_bases: List[float],
                    dry_allowed_lengths: List[int],
                    dry_sequence_breaker_ids: List[List[int]],
-                   sampling_seeds: List[List[int]],
+                   skews: List[float], sampling_seeds: List[List[int]],
                    sample_indices: List[int], prompt_tokens: List[array],
                    output_tokens: List[array], vocab_size: int,
                    extra_seeds_to_generate: int, device: torch.device,
@@ -786,6 +792,12 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
+        skews_t = torch.tensor(
+            skews,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
 
         sample_indices_t = torch.tensor(
             sample_indices,
@@ -853,6 +865,7 @@ class SamplingTensors:
                                                          non_blocking=True),
             dry_sequence_breaker_ids=dry_sequence_breakers_t.to(device=device,
                                                                 non_blocking=True),
+            skews=skews_t.to(device=device, non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),
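
As with the other per-sequence tensors, the request-level skew value is expanded to one entry per sequence in from_sampling_metadata and then packed into a pinned CPU tensor in from_lists before being moved to the device. A short sketch of that expansion, with made-up request values:

# Minimal sketch of the per-sequence expansion and tensor construction shown
# above. The (skew, n_seqs) pairs are illustrative.
import torch

requests = [(0.0, 2), (1.5, 3)]   # (params.skew, n_seqs) for two requests
skews = []
for skew, n_seqs in requests:
    skews += [skew] * n_seqs      # one entry per sequence, as in the diff

skews_t = torch.tensor(skews, device="cpu", dtype=torch.float)
print(skews_t)                    # tensor([0.0000, 0.0000, 1.5000, 1.5000, 1.5000])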

+ 53 - 0
tests/samplers/test_sampler.py

@@ -851,6 +851,59 @@ def test_sampler_nsigma(seed: int, device: str):
             "Top-nsigma sampling is not deterministic with same seed"
 
 
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_skew(seed: int, device: str):
+    """Test that skew sampling behaves as expected."""
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+
+    high_prob_tokens = {}
+    for i in range(batch_size):
+        # Make token i have a much higher logit in sequence i
+        fake_logits[i, i] = 10.0
+        high_prob_tokens[i] = i
+
+    test_cases = [
+        # (skew, expected_behavior)
+        (2.0, "low"),     # Strong bias away from high probability tokens
+        (0.5, "subtle"),  # Subtle bias away from high probability tokens
+        (0.0, "neutral"), # No bias (regular sampling)
+    ]
+
+    for skew, expected_behavior in test_cases:
+        sampling_params = SamplingParams(
+            temperature=1.0,  # neutral temperature
+            skew=skew,
+            seed=random.randint(0, 10000),  # for determinism
+        )
+
+        sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
+                                  sampling_params, device)
+
+        for batch_idx, sequence_output in enumerate(sampler_output):
+            token_id = sequence_output.samples[0].output_token
+
+            if expected_behavior == "low":
+                # strong skew should bias away from high probability tokens
+                assert token_id != high_prob_tokens[batch_idx], \
+                    f"With high skew {skew}, should not select high " \
+                    f"probability token {high_prob_tokens[batch_idx]}"
+
+            elif expected_behavior == "subtle":
+                # we don't assert anything for subtle effect,
+                # as it's probabilistic
+                pass
+
+        # determinism
+        second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
+                                 sampling_params, device)
+        assert sampler_output == second_output, \
+            f"Skew sampling with seed is not deterministic for skew={skew}"
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_include_gpu_probs_tensor(device: str):
     set_random_seed(42)