瀏覽代碼

jsd, kld, and dynamic typical p

AlpinDale 5 月之前
父節點
當前提交
63ff55acb5

+ 34 - 0
aphrodite/common/sampling_params.py

@@ -189,6 +189,10 @@ class SamplingParams:
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         xtc_threshold: float = 0.1,
         xtc_probability: float = 0,
+        kl_threshold: float = 0.0,
+        jsd_threshold: float = 0.0,
+        min_typical_p: float = 1.0,
+        max_typical_p: float = 1.0,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -253,6 +257,10 @@ class SamplingParams:
             self.output_text_buffer_length = 0
         self.xtc_threshold = xtc_threshold
         self.xtc_probability = xtc_probability
+        self.kl_threshold = kl_threshold
+        self.jsd_threshold = jsd_threshold
+        self.min_typical_p = min_typical_p
+        self.max_typical_p = max_typical_p
 
         self.default_values = {
             "n": 1,
@@ -294,6 +302,10 @@ class SamplingParams:
             "truncate_prompt_tokens": None,
             "xtc_threshold": 0.1,
             "xtc_probability": 0,
+            "kl_threshold": 0.0,
+            "jsd_threshold": 0.0,
+            "min_typical_p": 1.0,
+            "max_typical_p": 1.0,
         }
 
         # Number of characters to hold back for stop string evaluation
@@ -397,6 +409,28 @@ class SamplingParams:
             raise ValueError(
                 "xtc_probability must be in [0, 1], got "
                 f"{self.xtc_probability}.")
+        if self.kl_threshold < 0.0:
+            raise ValueError(
+                "kl_threshold must be non-negative, got "
+                f"{self.kl_threshold}.")
+        # jsd_threshold has to be between 0 and 1
+        if not 0.0 <= self.jsd_threshold <= 1.0:
+            raise ValueError(
+                "jsd_threshold must be in [0, 1], got "
+                f"{self.jsd_threshold}.")
+        if self.min_typical_p < 0.0 or self.min_typical_p > 1.0:
+            raise ValueError(
+                "min_typical_p must be in [0, 1], got "
+                f"{self.min_typical_p}.")
+        if self.max_typical_p < 0.0 or self.max_typical_p > 1.0:
+            raise ValueError(
+                "max_typical_p must be in [0, 1], got "
+                f"{self.max_typical_p}.")
+        if self.min_typical_p > self.max_typical_p:
+            raise ValueError(
+                "min_typical_p must be less than or equal to max_typical_p, "
+                f"got min_typical_p={self.min_typical_p} and "
+                f"max_typical_p={self.max_typical_p}.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:

+ 8 - 0
aphrodite/endpoints/openai/protocol.py

@@ -401,6 +401,10 @@ class CompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     xtc_threshold: Optional[float] = 0.1
     xtc_probability: Optional[float] = 0.0
+    kl_threshold: Optional[float] = 0.0
+    jsd_threshold: Optional[float] = 0.0
+    min_typical_p: Optional[float] = 1.0
+    max_typical_p: Optional[float] = 1.0
     dynatemp_min: Optional[float] = 0.0
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
@@ -503,6 +507,10 @@ class CompletionRequest(OpenAIBaseModel):
             temperature_last=self.temperature_last,
             xtc_threshold=self.xtc_threshold,
             xtc_probability=self.xtc_probability,
+            kl_threshold=self.kl_threshold,
+            jsd_threshold=self.jsd_threshold,
+            min_typical_p=self.min_typical_p,
+            max_typical_p=self.max_typical_p,
             dynatemp_min=self.dynatemp_min,
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,

+ 140 - 1
aphrodite/modeling/layers/sampler.py

@@ -72,7 +72,8 @@ class Sampler(nn.Module):
         # Initialize new sampling tensors
         (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
          do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-         do_typical_ps, do_quadratic, do_xtc, do_temp_last
+         do_typical_ps, do_quadratic, do_xtc, do_kl_threshold, do_jsd_threshold,
+         do_dynatypical_p, do_temp_last
          ) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
@@ -88,6 +89,9 @@ class Sampler(nn.Module):
         self._do_typical_ps = do_typical_ps
         self._do_quadratic = do_quadratic
         self._do_xtc = do_xtc
+        self._do_kl_threshold = do_kl_threshold
+        self._do_jsd_threshold = do_jsd_threshold
+        self._do_dynatypical_p = do_dynatypical_p
         self._do_temp_last = do_temp_last
 
     def forward(
@@ -126,6 +130,9 @@ class Sampler(nn.Module):
         do_typical_ps = self._do_typical_ps
         do_quadratic = self._do_quadratic
         do_xtc = self._do_xtc
+        do_kl_threshold = self._do_kl_threshold
+        do_jsd_threshold = self._do_jsd_threshold
+        do_dynatypical_p = self._do_dynatypical_p
         do_temp_last = self._do_temp_last
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
@@ -178,6 +185,18 @@ class Sampler(nn.Module):
             logits = _apply_xtc_sampling(
                 logits, sampling_tensors.xtc_thresholds,
                 sampling_tensors.xtc_probabilities)
+            
+        if do_kl_threshold:
+            logits = _apply_kl_divergence_sampling(
+                logits, sampling_tensors.kl_thresholds)
+
+        if do_jsd_threshold:
+            logits = _apply_jsd_sampling(logits, sampling_tensors.jsd_thresholds)
+
+        if do_dynatypical_p:
+            logits = _apply_dynamic_typical_sampling(
+                logits, sampling_tensors.min_typical_ps,
+                sampling_tensors.max_typical_ps)
 
         if do_temperatures and do_temp_last:
             _apply_temperatures(logits, sampling_tensors.temperatures,
@@ -539,6 +558,126 @@ def _apply_typical_sampling(
     return logits
 
 
+def _apply_dynamic_typical_sampling(
+    logits: torch.Tensor,
+    min_typical_p: torch.Tensor,
+    max_typical_p: torch.Tensor,
+) -> torch.Tensor:
+    """Applies typical sampling with an entropy-dependent typical_p.
+
+    For every sequence the typical_p threshold is interpolated linearly
+    between ``min_typical_p`` (zero entropy) and ``max_typical_p`` (uniform
+    distribution, i.e. maximum entropy). Tokens are ranked by how far their
+    surprisal ``-log p`` is from the entropy of the distribution, and the
+    least typical tokens are masked once the cumulative probability of the
+    more typical ones exceeds the threshold.
+
+    Args:
+        logits: Tensor of shape (batch_size, vocab_size) containing the logits.
+        min_typical_p: Tensor of shape (batch_size,) with the threshold used
+            when the normalized entropy is 0.
+        max_typical_p: Tensor of shape (batch_size,) with the threshold used
+            when the normalized entropy is 1.
+
+    Returns:
+        A new logits tensor with atypical tokens set to -inf. The most
+        typical token of every row is always kept.
+    """
+    shifted_logits = torch.log_softmax(logits, dim=-1)
+    probs = shifted_logits.exp()
+
+    # nansum: tokens that are already masked give 0 * -inf = NaN and must
+    # contribute nothing to the entropy.
+    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
+    entropy = -neg_entropy  # (batch_size, 1)
+
+    # Normalize entropy to [0, 1]. log(vocab_size) is 0 for a one-token
+    # vocabulary, so guard the division; the clamp absorbs rounding error.
+    vocab_size = logits.size(-1)
+    max_entropy = torch.log(
+        torch.tensor(vocab_size, dtype=logits.dtype, device=logits.device))
+    max_entropy = max_entropy.clamp_min(torch.finfo(logits.dtype).tiny)
+    normalized_entropy = (entropy / max_entropy).clamp(0.0, 1.0)
+
+    # The thresholds arrive as (batch_size,). Reshape them to (batch_size, 1)
+    # so they combine row-wise with the entropy instead of broadcasting to
+    # (batch_size, batch_size).
+    min_p = min_typical_p.reshape(-1, 1)
+    max_p = max_typical_p.reshape(-1, 1)
+    dynamic_typical_p = min_p + (max_p - min_p) * normalized_entropy
+
+    # Surprisal deviation |(-log p) - H| == |neg_entropy - log p|.
+    surprisal_deviations = (neg_entropy - shifted_logits).abs()
+
+    # Most typical tokens (smallest deviation) first.
+    _, indices = torch.sort(surprisal_deviations, dim=-1)
+    sorted_probs = probs.gather(-1, indices)
+    cum_probs = sorted_probs.cumsum(dim=-1)
+
+    # Mask tokens whose cumulative probability exceeds the threshold.
+    typ_mask_sorted = cum_probs > dynamic_typical_p
+
+    # Ensure at least one token is kept
+    typ_mask_sorted[..., 0] = False
+
+    # Scatter the mask back to the original vocabulary order
+    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
+
+    return logits.masked_fill(typ_mask, -float("inf"))
+
+
+def _apply_kl_divergence_sampling(
+    logits: torch.Tensor,
+    kl_thresholds: torch.Tensor,
+) -> torch.Tensor:
+    """Filters tokens by their pointwise KL contribution against uniform.
+
+    For every token the term ``p * (log p - log u)`` of ``KL(P || U)`` is
+    computed, where ``u = 1 / vocab_size``. Tokens whose term is below the
+    sequence's threshold are masked. A threshold of 0 (the default) leaves
+    the sequence untouched, so requests that did not enable the sampler are
+    unaffected when batched with requests that did.
+
+    Args:
+        logits: Tensor of shape (batch_size, vocab_size) containing the logits.
+        kl_thresholds: Tensor of shape (batch_size,) containing the KL
+            divergence thresholds (in nats).
+
+    Returns:
+        A new logits tensor with tokens below the threshold set to -inf.
+        The most likely token of every row is always kept.
+    """
+    # log_softmax is numerically safer than log(softmax + eps); a fixed
+    # epsilon such as 1e-8 rounds to zero in half precision.
+    log_probs = torch.log_softmax(logits, dim=-1)
+    probs = log_probs.exp()
+
+    # -log u == log(vocab_size)
+    log_vocab_size = torch.log(
+        torch.tensor(logits.size(-1),
+                     dtype=logits.dtype,
+                     device=logits.device))
+
+    # Pointwise KL term p * log(p / u). Tokens with p == 0 would give
+    # 0 * -inf = NaN; their contribution is defined as 0.
+    kl_terms = probs * (log_probs + log_vocab_size)
+    kl_terms = kl_terms.masked_fill(probs == 0, 0.0)
+
+    # Mask tokens below the threshold, but only for rows that enabled it.
+    thresholds = kl_thresholds.reshape(-1, 1)
+    kl_mask = (kl_terms < thresholds) & (thresholds > 0)
+
+    # Never mask the most likely token, so a row can never end up empty.
+    top_token = logits.argmax(dim=-1, keepdim=True)
+    kl_mask = kl_mask.scatter(
+        -1, top_token, torch.zeros_like(top_token, dtype=torch.bool))
+
+    return logits.masked_fill(kl_mask, -float("inf"))
+
+
+def _apply_jsd_sampling(
+    logits: torch.Tensor,
+    jsd_thresholds: torch.Tensor,
+) -> torch.Tensor:
+    """Filters tokens by their Jensen-Shannon contribution against uniform.
+
+    With ``m = (p + u) / 2`` and ``u = 1 / vocab_size``, every token
+    contributes ``0.5 * (p * log(p / m) + u * log(u / m))`` to the
+    Jensen-Shannon divergence between the predicted and the uniform
+    distribution. Each contribution is non-negative and they sum to at most
+    ``log(2)`` (natural log). Tokens whose contribution is below the
+    sequence's threshold are masked. A threshold of 0 (the default) leaves
+    the sequence untouched.
+
+    Args:
+        logits: Tensor of shape (batch_size, vocab_size) containing the logits.
+        jsd_thresholds: Tensor of shape (batch_size,) with the JSD threshold
+            for each sequence.
+
+    Returns:
+        A new logits tensor with tokens below the threshold set to -inf.
+        The most likely token of every row is always kept.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    uniform_prob = 1.0 / probs.size(-1)
+
+    # The mixture is strictly positive because uniform_prob is, so the
+    # ratios below need no epsilon (1e-8 underflows to 0 in half precision).
+    mixture = 0.5 * (probs + uniform_prob)
+
+    # p * log(p / m), defined as 0 where p == 0 (log(0) would give -inf).
+    log_p_ratio = torch.log(probs / mixture).masked_fill(probs == 0, 0.0)
+    p_terms = probs * log_p_ratio
+
+    # u * log(u / m)
+    u_terms = uniform_prob * torch.log(uniform_prob / mixture)
+
+    # Per-token contribution to the Jensen-Shannon divergence.
+    jsd_terms = 0.5 * (p_terms + u_terms)
+
+    # Mask tokens below the threshold, but only for rows that enabled it.
+    thresholds = jsd_thresholds.reshape(-1, 1)
+    jsd_mask = (jsd_terms < thresholds) & (thresholds > 0)
+
+    # Never mask the most likely token, so a row can never end up empty.
+    top_token = logits.argmax(dim=-1, keepdim=True)
+    jsd_mask = jsd_mask.scatter(
+        -1, top_token, torch.zeros_like(top_token, dtype=torch.bool))
+
+    return logits.masked_fill(jsd_mask, -float("inf"))
+
+
 def _apply_quadratic_sampling(
     logits: torch.Tensor,
     smoothing_factor: torch.Tensor,

+ 59 - 5
aphrodite/modeling/sampling_metadata.py

@@ -386,6 +386,10 @@ class SamplingTensors:
     smoothing_curves: torch.Tensor
     xtc_thresholds: torch.Tensor
     xtc_probabilities: torch.Tensor
+    kl_thresholds: torch.Tensor
+    jsd_thresholds: torch.Tensor
+    min_typical_ps: torch.Tensor
+    max_typical_ps: torch.Tensor
     sampling_seeds: torch.Tensor
     sample_indices: torch.Tensor
     extra_seeds: Optional[torch.Tensor]
@@ -403,7 +407,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -431,6 +435,10 @@ class SamplingTensors:
         smoothing_curves: List[float] = []
         xtc_thresholds: List[float] = []
         xtc_probabilities: List[float] = []
+        kl_thresholds: List[float] = []
+        jsd_thresholds: List[float] = []
+        min_typical_ps: List[float] = []
+        max_typical_ps: List[float] = []
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
         do_penalties = False
@@ -444,6 +452,9 @@ class SamplingTensors:
         do_typical_ps = False
         do_quadratic = False
         do_xtc = False
+        do_kl_threshold = False
+        do_jsd_threshold = False
+        do_dynatypical_p = False
         do_temp_last = False
 
         if _USE_TRITON_SAMPLER:
@@ -476,6 +487,10 @@ class SamplingTensors:
             smoothing_curve = sampling_params.smoothing_curve
             xtc_threshold = sampling_params.xtc_threshold
             xtc_probability = sampling_params.xtc_probability
+            kl_threshold = sampling_params.kl_threshold
+            jsd_threshold = sampling_params.jsd_threshold
+            min_typical_p = sampling_params.min_typical_p
+            max_typical_p = sampling_params.max_typical_p
 
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
@@ -511,6 +526,13 @@ class SamplingTensors:
                 do_quadratic = True
             if do_xtc is False and xtc_probability > _SAMPLING_EPS:
                 do_xtc = True
+            if do_kl_threshold is False and kl_threshold > _SAMPLING_EPS:
+                do_kl_threshold = True
+            if do_jsd_threshold is False and jsd_threshold > _SAMPLING_EPS:
+                do_jsd_threshold = True
+            if do_dynatypical_p is False and (min_typical_p < 1.0 - _SAMPLING_EPS
+                                              or max_typical_p < 1.0 - _SAMPLING_EPS):
+                do_dynatypical_p = True
             if do_temp_last is False and temperature_last:
                 do_temp_last = True
 
@@ -541,6 +563,10 @@ class SamplingTensors:
                 smoothing_curves += [smoothing_curve] * prefill_len
                 xtc_thresholds += [xtc_threshold] * prefill_len
                 xtc_probabilities += [xtc_probability] * prefill_len
+                kl_thresholds += [kl_threshold] * prefill_len
+                jsd_thresholds += [jsd_threshold] * prefill_len
+                min_typical_ps += [min_typical_p] * prefill_len
+                max_typical_ps += [max_typical_p] * prefill_len
 
             if seq_group.do_sample:
                 sample_lens = len(seq_group.sample_indices)
@@ -565,6 +591,10 @@ class SamplingTensors:
                 smoothing_curves += [smoothing_curve] * len(seq_ids)
                 xtc_thresholds += [xtc_threshold] * len(seq_ids)
                 xtc_probabilities += [xtc_probability] * len(seq_ids)
+                kl_thresholds += [kl_threshold] * len(seq_ids)
+                jsd_thresholds += [jsd_threshold] * len(seq_ids)
+                min_typical_ps += [min_typical_p] * len(seq_ids)
+                max_typical_ps += [max_typical_p] * len(seq_ids)
 
             if _USE_TRITON_SAMPLER:
                 if is_prompt:
@@ -609,12 +639,14 @@ class SamplingTensors:
             temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
             tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
-            sample_indices, prompt_tokens, output_tokens, vocab_size,
-            extra_seeds_to_generate, device, dtype)
+            smoothing_curves, xtc_thresholds, xtc_probabilities, kl_thresholds,
+            jsd_thresholds, min_typical_ps, max_typical_ps,
+            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
+            vocab_size, extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_temperatures,
                 do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
                 do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
+                do_kl_threshold, do_jsd_threshold, do_dynatypical_p,
                 do_temp_last)
 
     @classmethod
@@ -628,7 +660,9 @@ class SamplingTensors:
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], xtc_thresholds: List[float],
-                   xtc_probabilities: List[float], sampling_seeds: List[int],
+                   xtc_probabilities: List[float], kl_thresholds: List[float],
+                   jsd_thresholds: List[float], min_typical_ps: List[float],
+                   max_typical_ps: List[float], sampling_seeds: List[int],
                    sample_indices: List[int], prompt_tokens: List[array],
                    output_tokens: List[array], vocab_size: int,
                    extra_seeds_to_generate: int, device: torch.device,
@@ -760,6 +794,22 @@ class SamplingTensors:
                                            device="cpu",
                                            dtype=dtype,
                                            pin_memory=pin_memory)
+        kl_thresholds_t = torch.tensor(kl_thresholds,
+                                        device="cpu",
+                                        dtype=dtype,
+                                        pin_memory=pin_memory)
+        jsd_thresholds_t = torch.tensor(jsd_thresholds,
+                                        device="cpu",
+                                        dtype=dtype,
+                                        pin_memory=pin_memory)
+        min_typical_ps_t = torch.tensor(min_typical_ps,
+                                        device="cpu",
+                                        dtype=dtype,
+                                        pin_memory=pin_memory)
+        max_typical_ps_t = torch.tensor(max_typical_ps,
+                                        device="cpu",
+                                        dtype=dtype,
+                                        pin_memory=pin_memory)
         sample_indices_t = torch.tensor(
             sample_indices,
             device="cpu",
@@ -816,6 +866,10 @@ class SamplingTensors:
                                                non_blocking=True),
             xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                      non_blocking=True),
+            kl_thresholds=kl_thresholds_t.to(device=device, non_blocking=True),
+            jsd_thresholds=jsd_thresholds_t.to(device=device, non_blocking=True),
+            min_typical_ps=min_typical_ps_t.to(device=device, non_blocking=True),
+            max_typical_ps=max_typical_ps_t.to(device=device, non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),