Prechádzať zdrojové kódy

feat: Add DRY (Don't Repeat Yourself) sampling (#827)

* feat(sampling): Add DRY (Do not Repeat Yourself) sampling

Adds a logits processor that exponentially penalizes repetitive sequences.
Parameters control penalty strength (multiplier), growth rate (base), minimum
length before penalties (allowed_length), and tokens that break sequence
matching (sequence_breakers).

Key changes:
- Add DRY parameters to SamplingParams
- Add DRY tensors to SamplingTensors
- Implement _apply_dry() logits processor

* Set useful defaults for DRY API

* Fix refactor bug

* Fix naming bug

* Refactor based on closer examination of other examples

* Revert "Refactor based on closer examination of other examples"

This reverts commit d9562b96f814c11ad3f69197b137a09e7f2ab4ff.

I am lost

* take sequence breakers as both a string literal list and a List[int]

The original ooba implementation wraps the entire thing in a string
literal,
e.g. `dry_sequence_breakers: '["\\n",":","\\"","*"]'`. This is obviously
not a valid list, so we have to handle that case as well.

* formatting

* fix: sequence breaker ids are a list of int

* fix: init the prompt and output token tensors for dry

* fix dry penalties

Author: AThirdPath <discordianbelle@gmail.com>

* add co-author

Co-authored-by: AThirdPath <discordianbelle@gmail.com>

* apply dry first

* limit the backwards match to 50

---------

Co-authored-by: AlpinDale <alpindale@gmail.com>
Co-authored-by: AThirdPath <discordianbelle@gmail.com>
Selali 3 mesiacov pred
rodič
commit
4c4a365f77

+ 37 - 2
aphrodite/common/sampling_params.py

@@ -153,6 +153,23 @@ class SamplingParams(
             (max_logit - nsgima * std_dev) are filtered out. Higher values
             (e.g. 3.0) keep more tokens, lower values (e.g. 1.0) are more
             selective. Must be positive. 0 to disable.
+        dry_multiplier: Float that controls the magnitude of the DRY sampling
+            penalty. Higher values create stronger penalties against
+            repetition. The penalty is multiplied by this value before being
+            applied. Must be non-negative. 0 disables the sampler.
+        dry_base: Base for the exponential growth of the DRY sampling penalty.
+            Controls how quickly the penalty increases with longer repeated
+            sequences. Must be greater than 1. Higher values (e.g. 2.0) create
+            more aggressive penalties for longer repetitions. Defaults to 1.75.
+        dry_allowed_length: Maximum number of tokens that can be repeated
+            without incurring a DRY sampling penalty. Sequences longer than
+            this will be penalized exponentially. Must be at least 1.
+            Defaults to 2.
+        dry_sequence_breaker_ids: List of token IDs that stop
+            the matching of repeated content. These tokens will break up the
+            input into sections where repetition is evaluated separately.
+            Common examples are newlines, quotes, and other structural tokens.
+            Defaults to None.
     """
 
     n: int = 1
@@ -199,7 +216,10 @@ class SamplingParams(
     xtc_threshold: float = 0.1
     xtc_probability: float = 0
     nsigma: float = 0.0
-
+    dry_multiplier: float = 0.0
+    dry_base: float = 1.75
+    dry_allowed_length: int = 2
+    dry_sequence_breaker_ids: List[int] = []
     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
     output_text_buffer_length: int = 0
@@ -246,6 +266,10 @@ class SamplingParams(
         "xtc_threshold": 0.1,
         "xtc_probability": 0,
         "nsigma": 0.0,
+        "dry_multiplier": 0.0,
+        "dry_base": 1.75,
+        "dry_allowed_length": 2,
+        "dry_sequence_breaker_ids": [],
     }
 
     def __post_init__(self) -> None:
@@ -379,7 +403,18 @@ class SamplingParams(
             raise ValueError(
                 "nsigma must be non-negative, got "
                 f"{self.nsigma}.")
-            
+        if self.dry_multiplier < 0.0:
+            raise ValueError(
+                "dry_multiplier must be non-negative, got "
+                f"{self.dry_multiplier}.")
+        if self.dry_base <= 1.0:
+            raise ValueError(
+                "dry_base must be greater than 1, got "
+                f"{self.dry_base}.")
+        if self.dry_allowed_length < 0:
+            raise ValueError(
+                "dry_allowed_length must be non-negative, got "
+                f"{self.dry_allowed_length}.")    
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:

+ 60 - 0
aphrodite/endpoints/openai/protocol.py

@@ -1,5 +1,6 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import json
 import time
 from typing import Any, Dict, List, Literal, Optional, Union
 
@@ -147,6 +148,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     xtc_threshold: Optional[float] = 0.1
     xtc_probability: Optional[float] = 0.0
+    dry_multiplier: Optional[float] = 0
+    dry_base: Optional[float] = 1.75
+    dry_allowed_length: Optional[int] = 2
+    dry_sequence_breakers: Optional[List[str]] = Field(
+        default=["\n", ":", "\"", "*"])
     dynatemp_min: Optional[float] = 0.0
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
@@ -255,6 +261,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
         if guided_decode_logits_processor:
             logits_processors.append(guided_decode_logits_processor)
 
+
+        dry_sequence_breaker_ids = []
+        if self.dry_sequence_breakers:
+            for s in self.dry_sequence_breakers:
+                token_id = tokenizer.encode(f'a{s}')[-1]
+                dry_sequence_breaker_ids.append(token_id)
+
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -291,6 +304,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
             temperature_last=self.temperature_last,
             xtc_threshold=self.xtc_threshold,
             xtc_probability=self.xtc_probability,
+            dry_multiplier=self.dry_multiplier,
+            dry_base=self.dry_base,
+            dry_allowed_length=self.dry_allowed_length,
+            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
             dynatemp_min=self.dynatemp_min,
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,
@@ -403,6 +420,11 @@ class CompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     xtc_threshold: Optional[float] = 0.1
     xtc_probability: Optional[float] = 0.0
+    dry_multiplier: Optional[float] = 0
+    dry_base: Optional[float] = 1.75
+    dry_allowed_length: Optional[int] = 2
+    dry_sequence_breakers: Optional[List[str]] = Field(
+        default=["\n", ":", "\"", "*"])
     dynatemp_min: Optional[float] = 0.0
     dynatemp_max: Optional[float] = 0.0
     dynatemp_exponent: Optional[float] = 1.0
@@ -469,6 +491,13 @@ class CompletionRequest(OpenAIBaseModel):
         if guided_decode_logits_processor:
             logits_processors.append(guided_decode_logits_processor)
 
+        dry_sequence_breaker_ids = []
+        if self.dry_sequence_breakers:
+            for s in self.dry_sequence_breakers:
+                s = bytes(s, "utf-8").decode("unicode_escape")
+                token_id = tokenizer.encode(f'a{s}')[-1]
+                dry_sequence_breaker_ids.append(token_id)
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,
@@ -506,6 +535,10 @@ class CompletionRequest(OpenAIBaseModel):
             temperature_last=self.temperature_last,
             xtc_threshold=self.xtc_threshold,
             xtc_probability=self.xtc_probability,
+            dry_multiplier=self.dry_multiplier,
+            dry_base=self.dry_base,
+            dry_allowed_length=self.dry_allowed_length,
+            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
             dynatemp_min=self.dynatemp_min,
             dynatemp_max=self.dynatemp_max,
             dynatemp_exponent=self.dynatemp_exponent,
@@ -542,6 +575,33 @@ class CompletionRequest(OpenAIBaseModel):
             raise ValueError(
                 "Stream options can only be defined when stream is True.")
         return data
+    
+    @model_validator(mode='before')
+    @classmethod
+    def parse_dry_sequence_breakers(cls, data):
+        if 'dry_sequence_breakers' in data:
+            breakers = data['dry_sequence_breakers']
+            if isinstance(breakers, str):
+                try:
+                    # Try to parse as JSON string
+                    data['dry_sequence_breakers'] = json.loads(breakers)
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
+                                     f" {e}") from e
+                
+            # Validate that we now have a list of strings
+            is_list = isinstance(data['dry_sequence_breakers'], list)
+            all_strings = all(
+                isinstance(x, str) 
+                for x in data['dry_sequence_breakers']
+            )
+            if not is_list or not all_strings:
+                raise ValueError(
+                    "dry_sequence_breakers must be a list of strings or a "
+                    "JSON string representing a list of strings"
+                )
+        
+        return data
 
 
 class EmbeddingRequest(OpenAIBaseModel):

+ 100 - 1
aphrodite/modeling/layers/sampler.py

@@ -85,7 +85,7 @@ class Sampler(nn.Module):
         # Initialize new sampling tensors
         (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
          do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-         do_typical_ps, do_quadratic, do_xtc, do_nsigmas, do_temp_last
+         do_typical_ps, do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last
          ) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
@@ -102,6 +102,7 @@ class Sampler(nn.Module):
         self._do_quadratic = do_quadratic
         self._do_xtc = do_xtc
         self._do_nsgimas = do_nsigmas
+        self._do_dry = do_dry
         self._do_temp_last = do_temp_last
 
     def forward(
@@ -141,10 +142,21 @@ class Sampler(nn.Module):
         do_quadratic = self._do_quadratic
         do_xtc = self._do_xtc
         do_nsigmas = self._do_nsgimas
+        do_dry = self._do_dry
         do_temp_last = self._do_temp_last
 
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
+        if do_dry:
+            logits = _apply_dry(
+                logits,
+                sampling_tensors.prompt_tokens,
+                sampling_tensors.dry_multipliers,
+                sampling_tensors.dry_bases, 
+                sampling_tensors.dry_allowed_lengths,
+                sampling_tensors.dry_sequence_breaker_ids
+            )
+
         # Apply presence and frequency penalties.
         if do_penalties:
             logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
@@ -406,6 +418,93 @@ def _apply_min_tokens_penalty(
     assert logits_applied == logits.shape[0]
     return logits
 
+def _apply_dry(
+    logits: torch.Tensor,
+    input_ids: torch.Tensor,
+    multipliers: torch.Tensor, 
+    bases: torch.Tensor,
+    allowed_lengths: torch.Tensor,
+    sequence_breakers_ids: torch.Tensor
+) -> torch.Tensor:
+    """
+    Apply Exclude Don't Repeat Yourself (DRY) sampling to the logits.
+
+    Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
+    """
+    # Don't apply dry penalties if multiplier is 0
+    if torch.all(multipliers == 0):
+        return logits
+    
+    # Process each sequence in the batch
+    for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)):
+        multiplier = multipliers[i].item()
+        if multiplier == 0:
+            continue  # Skip processing for this sequence
+        # Get the last token
+        last_token = input_ids_row[-1].item()
+
+        # Skip if last token is a sequence breaker
+        if last_token in sequence_breakers_ids:
+            continue
+
+        # Find matches of the last token, excluding the last position
+        match_indices = (input_ids_row[:-1] == last_token).nonzero()
+
+        # Track max matching sequence length for each potential next token
+        match_lengths = {}
+
+        # Process each match
+        for idx in match_indices:
+            # Convert to scalar
+            idx = idx.item()
+            
+            # Get the token that followed this match in the input
+            next_token = input_ids_row[idx + 1].item()
+
+            # Skip if next token is a sequence breaker
+            if next_token in sequence_breakers_ids:
+                continue
+
+            # We found last_token matches at this index, so match length starts
+            # at 1
+            match_length = 1
+
+            # Try to extend match backwards
+            while match_length < 50:
+                j = idx - match_length
+                k = len(input_ids_row) - match_length - 1
+                if j < 0 or k < 0:
+                    # Reached start of input
+                    break
+
+                if input_ids_row[j].item() != input_ids_row[k].item():
+                    # No more matches
+                    break
+
+                if input_ids_row[k].item() in sequence_breakers_ids:
+                    # Hit a sequence breaker
+                    break
+
+                match_length += 1
+
+            # Update max match length for this next token
+            if next_token in match_lengths:
+                match_lengths[next_token] = max(
+                    match_length, match_lengths[next_token])
+            else:
+                match_lengths[next_token] = match_length
+
+        # Apply penalties based on match lengths
+        allowed_length = allowed_lengths[i]
+        multiplier = multipliers[i]  
+        base = bases[i]
+
+        for token, match_length in match_lengths.items():
+            if match_length >= allowed_length:
+                penalty = multiplier * (base ** (match_length - allowed_length))
+                logits_row[token] -= penalty
+
+    return logits
 
 def _apply_top_k_top_p(
     logits: torch.Tensor,

+ 59 - 5
aphrodite/modeling/sampling_metadata.py

@@ -388,6 +388,10 @@ class SamplingTensors:
     xtc_thresholds: torch.Tensor
     xtc_probabilities: torch.Tensor
     nsigmas: torch.Tensor
+    dry_multipliers: torch.Tensor
+    dry_bases: torch.Tensor
+    dry_allowed_lengths: torch.Tensor
+    dry_sequence_breaker_ids: torch.Tensor
     sampling_seeds: torch.Tensor
     sample_indices: torch.Tensor
     extra_seeds: Optional[torch.Tensor]
@@ -405,7 +409,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -436,6 +440,11 @@ class SamplingTensors:
         nsigmas: List[float] = []
         sampling_seeds: List[List[int]] = []
         sample_indices: List[int] = []
+        dry_multipliers: List[float] = []
+        dry_bases: List[float] = []
+        dry_allowed_lengths: List[int] = []
+        dry_sequence_breaker_ids: List[List[int]] = []
+
         do_penalties = False
         do_temperatures = False
         do_top_p_top_k = False
@@ -448,6 +457,7 @@ class SamplingTensors:
         do_quadratic = False
         do_xtc = False
         do_nsigmas = False
+        do_dry = False
         do_temp_last = False
 
         if _USE_TRITON_SAMPLER:
@@ -491,6 +501,8 @@ class SamplingTensors:
                              params.smoothing_curve > 1.0)
             do_xtc |= params.xtc_probability > _SAMPLING_EPS
             do_nsigmas |= params.nsigma > _SAMPLING_EPS
+            do_dry |= params.dry_multiplier > _SAMPLING_EPS
+
             do_temp_last |= params.temperature_last
 
             is_prompt = seq_group.is_prompt
@@ -526,6 +538,11 @@ class SamplingTensors:
             xtc_thresholds += [params.xtc_threshold] * n_seqs
             xtc_probabilities += [params.xtc_probability] * n_seqs
             nsigmas += [params.nsigma] * n_seqs
+            dry_multipliers += [params.dry_multiplier] * n_seqs
+            dry_bases += [params.dry_base] * n_seqs
+            dry_allowed_lengths += [params.dry_allowed_length] * n_seqs
+            dry_sequence_breaker_ids += (
+                [params.dry_sequence_breaker_ids] * n_seqs)
 
             if _USE_TRITON_SAMPLER:
                 if is_prompt:
@@ -549,7 +566,7 @@ class SamplingTensors:
                     sampling_seeds.append(seq_seeds)
                 sample_indices.extend(seq_group.sample_indices)
 
-        if do_penalties:
+        if do_penalties or do_dry:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
                 if (seq_group.is_prompt
@@ -573,12 +590,14 @@ class SamplingTensors:
             presence_penalties, frequency_penalties, repetition_penalties,
             tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
             smoothing_curves, xtc_thresholds, xtc_probabilities, nsigmas,
-            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
-            vocab_size, extra_seeds_to_generate, device, dtype)
+            dry_multipliers, dry_bases, dry_allowed_lengths,
+            dry_sequence_breaker_ids, sampling_seeds, sample_indices,
+            prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate,
+            device, dtype)
         return (sampling_tensors, do_penalties, do_temperatures,
                 do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
                 do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
-                do_nsigmas, do_temp_last)
+                do_nsigmas, do_dry, do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
@@ -592,6 +611,9 @@ class SamplingTensors:
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], xtc_thresholds: List[float],
                    xtc_probabilities: List[float], nsigmas: List[float],
+                   dry_multipliers: List[float], dry_bases: List[float],
+                   dry_allowed_lengths: List[int],
+                   dry_sequence_breaker_ids: List[List[int]],
                    sampling_seeds: List[List[int]],
                    sample_indices: List[int], prompt_tokens: List[array],
                    output_tokens: List[array], vocab_size: int,
@@ -728,6 +750,31 @@ class SamplingTensors:
                                  device="cpu",
                                  dtype=dtype,
                                  pin_memory=pin_memory)
+        dry_multipliers_t = torch.tensor(
+            dry_multipliers,
+            device="cpu", 
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dry_bases_t = torch.tensor(
+            dry_bases,
+            device="cpu",
+            dtype=dtype, 
+            pin_memory=pin_memory,
+        )
+        dry_allowed_lengths_t = torch.tensor(
+            dry_allowed_lengths,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
+        dry_sequence_breakers_t = torch.tensor(
+            dry_sequence_breaker_ids, 
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+
         sample_indices_t = torch.tensor(
             sample_indices,
             device="cpu",
@@ -785,6 +832,13 @@ class SamplingTensors:
             xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                      non_blocking=True),
             nsigmas=nsigmas_t.to(device=device, non_blocking=True),
+            dry_multipliers=dry_multipliers_t.to(device=device,
+                                                 non_blocking=True),
+            dry_bases=dry_bases_t.to(device=device, non_blocking=True),
+            dry_allowed_lengths=dry_allowed_lengths_t.to(device=device,
+                                                         non_blocking=True),
+            dry_sequence_breaker_ids=dry_sequence_breakers_t.to(device=device,
+                                                                non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),