AlpinDale 6 months ago
parent
commit
a11dee6352

+ 58 - 0
aphrodite/common/sampling_params.py

@@ -1,4 +1,5 @@
 """Sampling parameters for text generation."""
+import ast
 import copy
 import os
 from enum import IntEnum
@@ -62,6 +63,16 @@ class SamplingParams:
             freq_pen is applied additively while
             rep_pen is applied multiplicatively.
             Must be in [1, inf). Set to 1 to disable the effect.
+        dry_multiplier: Float that controls the magnitude of the penalty for
+            the shortest penalized sequences in the DRY sampler. Set to values
+            greater than 0 to enable DRY sampling.
+        dry_base: Float that controls how fast the penalty grows with
+            increasing sequence length in the DRY sampler.
+        dry_allowed_length: Integer that controls the maximum length of
+            sequences that can be repeated without being penalized in the DRY
+            sampler.
+        dry_sequence_breakers: Tokens across which sequence matching is not
+            continued. Specified as a comma-separated list of quoted strings.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
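A standalone sketch of how the default breaker string documented above is
parsed; this mirrors the _parse_dry_sequence_breakers helper added below and
is illustrative, not part of the diff:

    import ast

    raw = '"\\n", ":", "\\"", "*"'        # the documented default
    parsed = ast.literal_eval(f"[{raw}]")  # wrap in brackets -> list literal
    print(parsed)                          # ['\n', ':', '"', '*']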
@@ -150,6 +161,10 @@ class SamplingParams:
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
+        dry_multiplier: float = 0.0,
+        dry_base: float = 1.75,
+        dry_allowed_length: int = 2,
+        dry_sequence_breakers: Union[str, List[List[int]]] = (
+            '"\\n", ":", "\\"", "*"'),
         temperature: float = 1.0,
         temperature_last: bool = False,
         top_p: float = 1.0,
@@ -186,6 +201,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.repetition_penalty = repetition_penalty
+        self.dry_multiplier = dry_multiplier
+        self.dry_base = dry_base
+        self.dry_allowed_length = dry_allowed_length
+        self.dry_sequence_breakers = self._parse_dry_sequence_breakers(
+            dry_sequence_breakers)
         if 0 < temperature < _MAX_TEMP:
             logger.warning(
                 f"temperature {temperature} is less than {_MAX_TEMP}, "
@@ -246,6 +265,10 @@ class SamplingParams:
             "presence_penalty": 0.0,
             "frequency_penalty": 0.0,
             "repetition_penalty": 1.0,
+            "dry_multiplier": 0.0,
+            "dry_base": 1.75,
+            "dry_allowed_length": 2,
+            "dry_sequence_breakers": '"\\n", ":", "\\"", "*"',
             "temperature": 1.0,
             "temperature_last": False,
             "top_p": 1.0,
@@ -305,6 +328,29 @@ class SamplingParams:
         # eos_token_id is added to this by the engine
         self.all_stop_token_ids = set(self.stop_token_ids)
 
+    def _parse_dry_sequence_breakers(
+        self, dry_sequence_breakers: Union[str, List[List[int]]]
+    ) -> Union[List[str], List[List[int]]]:
+        if isinstance(dry_sequence_breakers, list):
+            return dry_sequence_breakers
+        try:
+            # Use ast.literal_eval to safely evaluate the string as a Python expression
+            parsed = ast.literal_eval(f'[{dry_sequence_breakers}]')
+            return [str(item) for item in parsed]
+        except (SyntaxError, ValueError):
+            # If parsing fails, return the original string as a single-item list
+            return [dry_sequence_breakers]
+
+    def tokenize_dry_sequence_breakers(self, tokenizer):
+        if not self.dry_sequence_breakers or not isinstance(
+                self.dry_sequence_breakers[0], str):
+            # Nothing to tokenize, or already token-id lists.
+            return
+
+        tokenized_breakers = []
+        for breaker in self.dry_sequence_breakers:
+            tokenized_breaker = tokenizer.encode(breaker,
+                                                 add_special_tokens=False)
+            tokenized_breakers.append(tokenized_breaker)
+
+        self.dry_sequence_breakers = tokenized_breakers
+
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
@@ -320,6 +366,18 @@ class SamplingParams:
         if self.repetition_penalty < 1.0:
             raise ValueError("repetition_penalty must be in [1, inf), got "
                              f"{self.repetition_penalty}.")
+        if self.dry_multiplier < 0.0:
+            raise ValueError("dry_multiplier must be non-negative, got "
+                             f"{self.dry_multiplier}.")
+        if self.dry_base < 1.0:
+            raise ValueError(
+                f"dry_base must be at least 1, got {self.dry_base}.")
+        if self.dry_allowed_length < 1:
+            raise ValueError("dry_allowed_length must be at least 1, got "
+                             f"{self.dry_allowed_length}.")
+        if not (all(isinstance(s, str)
+                    for s in self.dry_sequence_breakers)
+                or all(isinstance(s, list)
+                       for s in self.dry_sequence_breakers)):
+            raise ValueError("dry_sequence_breakers must be a list of "
+                             "strings or a list of token-id lists.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")

+ 20 - 0
aphrodite/endpoints/openai/protocol.py

@@ -375,6 +375,10 @@ class CompletionRequest(OpenAIBaseModel):
     smoothing_factor: Optional[float] = 0.0
     smoothing_curve: Optional[float] = 1.0
     repetition_penalty: Optional[float] = 1.0
+    dry_multiplier: Optional[float] = 0.0
+    dry_base: Optional[float] = 1.75
+    dry_allowed_length: Optional[int] = 2
+    dry_sequence_breakers: Optional[List[str]] = Field(
+        default_factory=lambda: ["\n", ":", "\"", "*"])
     length_penalty: Optional[float] = 1.0
     early_stopping: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
@@ -431,6 +435,16 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
+    def _tokenize_dry_sequence_breakers(self, tokenizer: PreTrainedTokenizer):
+        if not self.dry_sequence_breakers:
+            return []
+
+        tokenized_breakers = []
+        for breaker in self.dry_sequence_breakers:
+            tokenized_breaker = tokenizer.encode(breaker,
+                                                 add_special_tokens=False)
+            # Keep each breaker as its own token-id list; flattening them
+            # would lose the per-breaker boundaries the sampler relies on.
+            tokenized_breakers.append(tokenized_breaker)
+        return tokenized_breakers
+
     def to_sampling_params(
             self, tokenizer: PreTrainedTokenizer,
             guided_decode_logits_processor: Optional[LogitsProcessorFunc],
@@ -449,6 +463,8 @@ class CompletionRequest(OpenAIBaseModel):
         if guided_decode_logits_processor:
             logits_processors.append(guided_decode_logits_processor)
 
+        tokenized_dry_sequence_breakers = \
+            self._tokenize_dry_sequence_breakers(tokenizer)
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,
@@ -484,6 +500,10 @@ class CompletionRequest(OpenAIBaseModel):
             logits_processors=logits_processors,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             temperature_last=self.temperature_last,
+            dry_multiplier=self.dry_multiplier,
+            dry_base=self.dry_base,
+            dry_allowed_length=self.dry_allowed_length,
+            dry_sequence_breakers=tokenized_dry_sequence_breakers,
         )
 
     @model_validator(mode="before")
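On the OpenAI-compatible server these fields simply ride along in the
completion request body. A sketch with placeholder host, port, and model name:

    import requests

    payload = {
        "model": "my-model",              # placeholder
        "prompt": "Once upon a time",
        "max_tokens": 128,
        "dry_multiplier": 0.8,
        "dry_base": 1.75,
        "dry_allowed_length": 2,
        "dry_sequence_breakers": ["\n", ":", "\"", "*"],
    }
    resp = requests.post("http://localhost:2242/v1/completions",
                         json=payload)
    print(resp.json()["choices"][0]["text"])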

+ 59 - 2
aphrodite/modeling/layers/sampler.py

@@ -70,13 +70,14 @@ class Sampler(nn.Module):
         self._sampling_tensors = None
 
         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
-         do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
+        (sampling_tensors, do_penalties, do_dries, do_top_p_top_k, do_top_as,
+         do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
          do_quadratic, do_temp_last) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
+        self._do_dries = do_dries
         self._do_top_p_top_k = do_top_p_top_k
         self._do_top_as = do_top_as
         self._do_min_p = do_min_p
@@ -113,6 +114,7 @@ class Sampler(nn.Module):
         assert self._sampling_tensors is not None
         sampling_tensors = self._sampling_tensors
         do_penalties = self._do_penalties
+        do_dries = self._do_dries
         do_top_p_top_k = self._do_top_p_top_k
         do_top_as = self._do_top_as
         do_min_p = self._do_min_p
@@ -132,6 +134,14 @@ class Sampler(nn.Module):
                                       sampling_tensors.presence_penalties,
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
+
+        if do_dries:
+            logits = _apply_dry_penalty(logits, sampling_tensors.prompt_tokens,
+                                        sampling_tensors.output_tokens,
+                                        sampling_tensors.dry_multipliers,
+                                        sampling_tensors.dry_bases,
+                                        sampling_tensors.dry_allowed_lengths,
+                                        sampling_tensors.dry_sequence_breakerss)
 
         # Apply temperature scaling if not doing temp_last.
         if not do_temp_last:
@@ -286,6 +296,53 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits
 
 
+def _apply_dry_penalty(
+    logits: torch.Tensor,
+    prompt_tokens: torch.Tensor,
+    output_tokens: torch.Tensor,
+    dry_multipliers: torch.Tensor,
+    dry_bases: torch.Tensor,
+    dry_allowed_lengths: torch.Tensor,
+    dry_sequence_breakerss: torch.Tensor,
+) -> torch.Tensor:
+    num_seqs, _ = logits.shape
+
+    for i in range(num_seqs):
+        input_ids = torch.cat([prompt_tokens[i], output_tokens[i]])
+        if input_ids.numel() == 0:
+            continue
+
+        last_token = input_ids[-1].item()
+        if any(last_token in breaker for breaker in dry_sequence_breakerss[i]):
+            continue
+
+        # nonzero(as_tuple=True) always yields a 1-D index tensor; a plain
+        # .squeeze() would produce an un-iterable 0-d tensor when there is
+        # exactly one match.
+        match_indices = (input_ids[:-1] == last_token).nonzero(
+            as_tuple=True)[0]
+        match_lengths = {}
+
+        for idx in match_indices:
+            next_token = input_ids[idx + 1].item()
+            if any(next_token in breaker
+                   for breaker in dry_sequence_breakerss[i]):
+                continue
+
+            match_length = 1
+            while idx - match_length >= 0:
+                previous_token = input_ids[-(match_length + 1)].item()
+                if (input_ids[idx - match_length] != previous_token
+                        or any(previous_token in breaker
+                               for breaker in dry_sequence_breakerss[i])):
+                    break
+                match_length += 1
+
+            match_lengths[next_token] = max(
+                match_lengths.get(next_token, 0), match_length)
+
+        for token, match_length in match_lengths.items():
+            if match_length >= dry_allowed_lengths[i]:
+                penalty = dry_multipliers[i] * dry_bases[i] ** (
+                    match_length - dry_allowed_lengths[i])
+                logits[i, token] -= penalty
+
+    return logits
+
+
 def _apply_token_bans(logits: torch.Tensor,
                       banned_tokens: List[List[int]]) -> torch.Tensor:
     for i, banned_token_ids in enumerate(banned_tokens):
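To make the penalty schedule concrete, a standalone calculation using the same
formula as _apply_dry_penalty with the default dry_base=1.75 and
dry_allowed_length=2 (the penalty is subtracted from the logit of the token
that would extend the repeat):

    multiplier, base, allowed_length = 0.8, 1.75, 2
    for match_length in range(2, 7):
        penalty = multiplier * base ** (match_length - allowed_length)
        print(f"match_length={match_length}: penalty={penalty:.2f}")
    # match_length=2: penalty=0.80
    # match_length=3: penalty=1.40
    # match_length=4: penalty=2.45
    # match_length=5: penalty=4.29
    # match_length=6: penalty=7.50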

+ 79 - 8
aphrodite/modeling/sampling_metadata.py

@@ -254,6 +254,7 @@ def _prepare_seq_groups(
 
             sample_obj.prompt_logprob_indices.clear()
             sample_obj.sample_indices.clear()
         sampling_params = seq_group_metadata.sampling_params
         is_prompt = seq_group_metadata.is_prompt
         generator: Optional[torch.Generator] = None
@@ -265,6 +266,7 @@ def _prepare_seq_groups(
         sample_indices: List[int] = \
             sample_obj.sample_indices if cache is not None else []
         do_sample = seq_group_metadata.do_sample
 
         if seq_group_metadata.is_prompt:
             if sampling_params.seed is not None:
@@ -375,6 +377,10 @@ class SamplingTensors:
     presence_penalties: torch.Tensor
     frequency_penalties: torch.Tensor
     repetition_penalties: torch.Tensor
+    dry_multipliers: torch.Tensor
+    dry_bases: torch.Tensor
+    dry_allowed_lengths: torch.Tensor
+    dry_sequence_breakerss: torch.Tensor
     tfss: torch.Tensor
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
@@ -398,7 +404,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool]:
+               bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -415,6 +421,10 @@ class SamplingTensors:
         presence_penalties: List[float] = []
         frequency_penalties: List[float] = []
         repetition_penalties: List[float] = []
+        dry_multipliers: List[float] = []
+        dry_bases: List[float] = []
+        dry_allowed_lengths: List[int] = []
+        dry_sequence_breakerss: List[List[List[int]]] = []
         tfss: List[float] = []
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
@@ -424,6 +434,7 @@ class SamplingTensors:
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
         do_penalties = False
+        do_dries = False
         do_top_p_top_k = False
         do_top_as = False
         do_min_p = False
@@ -450,6 +461,10 @@ class SamplingTensors:
             p = sampling_params.presence_penalty
             f = sampling_params.frequency_penalty
             r = sampling_params.repetition_penalty
+            dry_multiplier = sampling_params.dry_multiplier
+            dry_base = sampling_params.dry_base
+            dry_allowed_length = sampling_params.dry_allowed_length
+            dry_sequence_breakers = sampling_params.dry_sequence_breakers
             top_p = sampling_params.top_p
             top_a = sampling_params.top_a
             min_p = sampling_params.min_p
@@ -479,6 +494,8 @@ class SamplingTensors:
                                      or abs(f) >= _SAMPLING_EPS
                                      or abs(r - 1.0) >= _SAMPLING_EPS):
                 do_penalties = True
+            if do_dries is False and dry_multiplier > _SAMPLING_EPS:
+                do_dries = True
             if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                 do_tfss = True
             if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
@@ -509,6 +526,10 @@ class SamplingTensors:
                 presence_penalties += [0] * prefill_len
                 frequency_penalties += [0] * prefill_len
                 repetition_penalties += [1] * prefill_len
+                dry_multipliers += [0] * prefill_len
+                dry_bases += [0] * prefill_len
+                dry_allowed_lengths += [0] * prefill_len
+                # No breakers for prompt-logprob rows; an empty list keeps
+                # the per-row structure (a list of breaker token-id lists).
+                dry_sequence_breakerss += [[]] * prefill_len
                 tfss += [1] * prefill_len
                 eta_cutoffs += [0] * prefill_len
                 epsilon_cutoffs += [0] * prefill_len
@@ -528,6 +549,10 @@ class SamplingTensors:
                 presence_penalties += [p] * len(seq_ids)
                 frequency_penalties += [f] * len(seq_ids)
                 repetition_penalties += [r] * len(seq_ids)
+                dry_multipliers += [dry_multiplier] * len(seq_ids)
+                dry_bases += [dry_base] * len(seq_ids)
+                dry_allowed_lengths += [dry_allowed_length] * len(seq_ids)
+                dry_sequence_breakerss += [dry_sequence_breakers] * len(seq_ids)
                 tfss += [tfs] * len(seq_ids)
                 eta_cutoffs += [eta_cutoff] * len(seq_ids)
                 epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
@@ -576,12 +601,14 @@ class SamplingTensors:
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
-            tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, sampling_seeds, sample_indices, prompt_tokens,
-            output_tokens, vocab_size, extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
-                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-                do_typical_ps, do_quadratic, do_temp_last)
+            dry_multipliers, dry_bases, dry_allowed_lengths,
+            dry_sequence_breakerss, tfss, eta_cutoffs, epsilon_cutoffs,
+            typical_ps, smoothing_factors, smoothing_curves, sampling_seeds,
+            sample_indices, prompt_tokens, output_tokens, vocab_size,
+            extra_seeds_to_generate, device, dtype)
+        return (sampling_tensors, do_penalties, do_dries, do_top_p_top_k,
+                do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
+                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float],
@@ -589,7 +616,10 @@ class SamplingTensors:
                    top_ks: List[int], top_as: List[float],
                    min_ps: List[float], presence_penalties: List[float],
                    frequency_penalties: List[float],
-                   repetition_penalties: List[float], tfss: List[float],
+                   repetition_penalties: List[float],
+                   dry_multipliers: List[float], dry_bases: List[float],
+                   dry_allowed_lengths: List[int],
+                   dry_sequence_breakerss: List[List[List[int]]],
+                   tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], sampling_seeds: List[int],
@@ -668,6 +698,30 @@ class SamplingTensors:
             dtype=dtype,
             pin_memory=pin_memory,
         )
+        dry_multipliers_t = torch.tensor(
+            dry_multipliers,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dry_bases_t = torch.tensor(
+            dry_bases,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dry_allowed_lengths_t = torch.tensor(
+            dry_allowed_lengths,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
         top_ks_t = torch.tensor(
             top_ks,
             device="cpu",
@@ -726,6 +780,16 @@ class SamplingTensors:
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
 
+        # Pack the ragged per-sequence breaker lists into a dense
+        # (num_seqs, max_breakers, max_breaker_length) tensor padded with -1,
+        # which never collides with a real token id. default=1 keeps the
+        # shape valid when no sequence in the batch has breakers.
+        max_breakers = max(
+            (len(breakers) for breakers in dry_sequence_breakerss), default=1)
+        max_breaker_length = max(
+            (len(breaker) for breakers in dry_sequence_breakerss
+             for breaker in breakers), default=1)
+
+        dry_sequence_breakerss_t = torch.full(
+            (len(dry_sequence_breakerss), max_breakers, max_breaker_length),
+            -1, device="cpu", dtype=torch.long, pin_memory=pin_memory)
+
+        for i, breakers in enumerate(dry_sequence_breakerss):
+            for j, breaker in enumerate(breakers):
+                dry_sequence_breakerss_t[i, j, :len(breaker)] = torch.tensor(
+                    breaker, dtype=torch.long)
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             temperature_lasts=temp_lasts_t.to(device=device, non_blocking=True),
@@ -739,6 +803,13 @@ class SamplingTensors:
                                                          non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
+            dry_multipliers=dry_multipliers_t.to(device=device,
+                                                 non_blocking=True),
+            dry_bases=dry_bases_t.to(device=device, non_blocking=True),
+            dry_allowed_lengths=dry_allowed_lengths_t.to(
+                device=device, non_blocking=True),
+            dry_sequence_breakerss=dry_sequence_breakerss_t.to(
+                device=device, non_blocking=True),
             tfss=tfss_t.to(device=device, non_blocking=True),
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
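Finally, a self-contained sketch of the padding scheme used for
dry_sequence_breakerss_t above: ragged per-sequence breaker token lists are
packed into one (num_seqs, max_breakers, max_breaker_length) tensor, padded
with -1 (the token ids here are made up):

    import torch

    dry_sequence_breakerss = [
        [[198], [25]],      # seq 0: two single-token breakers
        [[11, 12, 13]],     # seq 1: one three-token breaker
    ]
    max_breakers = max(len(bs) for bs in dry_sequence_breakerss)
    max_len = max(len(b) for bs in dry_sequence_breakerss for b in bs)
    packed = torch.full((len(dry_sequence_breakerss), max_breakers, max_len),
                        -1, dtype=torch.long)
    for i, breakers in enumerate(dry_sequence_breakerss):
        for j, breaker in enumerate(breakers):
            packed[i, j, :len(breaker)] = torch.tensor(breaker,
                                                       dtype=torch.long)
    print(packed.shape)  # torch.Size([2, 2, 3])
    # -1 never collides with a real token id, so the sampler's
    # "token in breaker" membership checks ignore the padding.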