
feat: add no_repeat_ngram sampler (#832)

AlpinDale · 3 months ago
commit ba9d8f631a

aphrodite/common/sampling_params.py  (+5, -0)

@@ -65,6 +65,9 @@ class SamplingParams(
             freq_pen is applied additively while
             rep_pen is applied multiplicatively.
             Must be in [1, inf). Set to 1 to disable the effect.
+        no_repeat_ngram_size: Size of the n-grams to prevent repeating.
+            1 would mean no token can appear twice.
+            2 would mean no pair of consecutive tokens can appear twice.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
@@ -177,6 +180,7 @@ class SamplingParams(
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
     repetition_penalty: float = 1.0
+    no_repeat_ngram_size: int = 0
     temperature: float = 1.0
     dynatemp_min: float = 0.0
     dynatemp_max: float = 0.0
@@ -231,6 +235,7 @@ class SamplingParams(
         "presence_penalty": 0.0,
         "frequency_penalty": 0.0,
         "repetition_penalty": 1.0,
+        "no_repeat_ngram_size": 0,
         "temperature": 1.0,
         "dynatemp_min": 0.0,
         "dynatemp_max": 0.0,

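For context, a minimal offline-usage sketch of the new parameter. The `no_repeat_ngram_size` field itself comes from the diff above; the `LLM` entry point, import path, and model id are assumptions in vLLM-style usage and may differ in your install.

# Hedged sketch: assumes `LLM` and `SamplingParams` are exported from the
# top-level `aphrodite` package (vLLM-style); adjust imports if they differ.
from aphrodite import LLM, SamplingParams

params = SamplingParams(
    temperature=0.8,
    max_tokens=128,
    no_repeat_ngram_size=3,  # no 3-gram may occur twice; 0 (default) disables
)

llm = LLM(model="your-model-id")  # placeholder model id
outputs = llm.generate(["Tell me a story about a lighthouse."], params)
print(outputs[0].outputs[0].text)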
aphrodite/endpoints/openai/protocol.py  (+4, -0)

@@ -136,6 +136,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     smoothing_factor: Optional[float] = 0.0
     smoothing_curve: Optional[float] = 1.0
     repetition_penalty: Optional[float] = 1.0
+    no_repeat_ngram_size: Optional[int] = 0
     length_penalty: Optional[float] = 1.0
     early_stopping: Optional[bool] = False
     ignore_eos: Optional[bool] = False
@@ -273,6 +274,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
+            no_repeat_ngram_size=self.no_repeat_ngram_size,
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
@@ -405,6 +407,7 @@ class CompletionRequest(OpenAIBaseModel):
     smoothing_factor: Optional[float] = 0.0
     smoothing_curve: Optional[float] = 1.0
     repetition_penalty: Optional[float] = 1.0
+    no_repeat_ngram_size: Optional[int] = 0
     length_penalty: Optional[float] = 1.0
     early_stopping: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
@@ -504,6 +507,7 @@ class CompletionRequest(OpenAIBaseModel):
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
+            no_repeat_ngram_size=self.no_repeat_ngram_size,
             temperature=self.temperature,
             top_p=self.top_p,
             top_k=self.top_k,

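The new request field can be sent alongside the standard OpenAI-compatible parameters. A hedged request sketch follows; the host, port, and model id are placeholders, not values from this diff.

# Hedged sketch: host, port, and model id are placeholders for your deployment.
import requests

resp = requests.post(
    "http://localhost:2242/v1/chat/completions",  # assumed server address
    json={
        "model": "your-model-id",  # placeholder
        "messages": [{"role": "user", "content": "Write a limerick."}],
        "temperature": 0.8,
        "no_repeat_ngram_size": 2,  # the field added in this diff
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])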
aphrodite/modeling/layers/sampler.py  (+123, -3)

@@ -83,14 +83,16 @@ class Sampler(nn.Module):
         self._sampling_tensors = None
 
         # Initialize new sampling tensors
-        (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
-         do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-         do_typical_ps, do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last
+        (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
+         do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
+         do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
+         do_dry, do_temp_last
          ) = SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         self._sampling_tensors = sampling_tensors
         self._do_penalties = do_penalties
+        self._do_no_repeat_ngrams = do_no_repeat_ngrams
         self._do_temperatures = do_temperatures
         self._do_top_p_top_k = do_top_p_top_k
         self._do_top_as = do_top_as
@@ -131,6 +133,7 @@ class Sampler(nn.Module):
         assert self._sampling_tensors is not None
         sampling_tensors = self._sampling_tensors
         do_penalties = self._do_penalties
+        do_no_repeat_ngrams = self._do_no_repeat_ngrams
         do_temperatures = self._do_temperatures
         do_top_p_top_k = self._do_top_p_top_k
         do_top_as = self._do_top_as
@@ -164,6 +167,12 @@ class Sampler(nn.Module):
                                       sampling_tensors.presence_penalties,
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
+        
+        if do_no_repeat_ngrams:
+            logits = _apply_no_repeat_ngram(
+                logits,
+                sampling_tensors.prompt_tokens,
+                sampling_tensors.no_repeat_ngram_sizes)
 
         # Apply temperature scaling if not doing temp_last.
         if do_temperatures and not do_temp_last:
@@ -506,6 +515,39 @@ def _apply_dry(
 
     return logits
 
+def _apply_no_repeat_ngram(
+    logits: torch.Tensor,
+    input_ids: torch.Tensor,
+    ngram_size: torch.Tensor,
+) -> torch.Tensor:
+    """Apply no-repeat-ngram penalty which sets logits to -inf for tokens that 
+    would create a repeated n-gram.
+    """
+    if torch.all(ngram_size == 0):
+        return logits
+
+    batch_size = logits.shape[0]
+
+    for i in range(batch_size):
+        size = int(ngram_size[i].item())
+        if size == 0:
+            continue
+
+        cur_len = len(input_ids[i])
+        if cur_len < size:
+            continue
+
+        banned_tokens = _calc_banned_ngram_tokens(
+            ngram_size=size,
+            prev_input_ids=input_ids[i],
+            cur_len=cur_len-1
+        )
+
+        if banned_tokens:
+            logits[i, banned_tokens] = -float("inf")
+
+    return logits
+
 def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
@@ -1606,6 +1648,84 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
         next_token_index_start:next_token_index_end]
     return next_prompt_tokens
 
+def _get_ngrams(
+    ngram_size: int, 
+    prev_input_ids: torch.Tensor
+) -> Dict[Tuple[int, ...], List[int]]:
+    """Get dictionary of ngrams and the tokens that followed them.
+
+    Args:
+        ngram_size: Size of ngrams to track
+        prev_input_ids: 1D tensor of previous token ids
+
+    Returns:
+        Dictionary mapping ngram tuples to list of tokens that followed them
+    """
+    generated_ngrams = {}
+    gen_tokens = prev_input_ids.tolist()
+
+    for i in range(len(gen_tokens) - ngram_size + 1):
+        ngram = tuple(gen_tokens[i:i + ngram_size - 1])
+        next_token = gen_tokens[i + ngram_size - 1]
+        if ngram in generated_ngrams:
+            generated_ngrams[ngram].append(next_token)
+        else:
+            generated_ngrams[ngram] = [next_token]
+
+    return generated_ngrams
+
+def _get_generated_ngrams(
+    banned_ngrams: Dict[Tuple[int, ...], List[int]], 
+    prev_input_ids: torch.Tensor,
+    ngram_size: int, 
+    cur_len: int
+) -> List[int]:
+    """Get list of tokens that would create a repeated ngram if generated next.
+
+    Args:
+        banned_ngrams: Dictionary of previously seen ngrams and their next
+            tokens
+        prev_input_ids: Previous token ids
+        ngram_size: Size of ngrams to check
+        cur_len: Current position in sequence
+
+    Returns:
+        List of token ids that would create a repeat ngram
+    """
+    start_idx = cur_len + 1 - ngram_size
+    current_ngram = tuple(prev_input_ids[start_idx:cur_len].tolist())
+
+    return banned_ngrams.get(current_ngram, [])
+
+def _calc_banned_ngram_tokens(
+    ngram_size: int,
+    prev_input_ids: torch.Tensor,
+    cur_len: int
+) -> List[int]:
+    """Calculate tokens that would create repeated ngrams if generated next.
+
+    Args:
+        ngram_size: Size of ngrams to prevent repeating
+        prev_input_ids: Previous token ids in sequence
+        cur_len: Current position in sequence
+
+    Returns:
+        List of token ids that should be banned to prevent ngram repetition
+    """
+    if cur_len + 1 < ngram_size:
+        return []
+
+    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids)
+
+    banned_tokens = _get_generated_ngrams(
+        generated_ngrams,
+        prev_input_ids, 
+        ngram_size,
+        cur_len
+    )
+
+    return banned_tokens
+
 
 # def _apply_mirostat_v2(logits: torch.Tensor,
 #                        sampling_tensors: SamplingTensors) -> torch.Tensor:

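For readers skimming the helpers above, here is a self-contained sketch of the banning rule: record the token that followed every (n-1)-token prefix seen so far, then ban whatever followed the current (n-1)-token suffix. This is an illustration of the general technique, not the repository's exact code path.

# Standalone illustration of no-repeat n-gram banning (not the repo's code).
from typing import Dict, List, Tuple

import torch


def banned_next_tokens(token_ids: List[int], ngram_size: int) -> List[int]:
    """Tokens that would complete an n-gram already present in token_ids."""
    if ngram_size <= 0 or len(token_ids) < ngram_size:
        return []
    seen: Dict[Tuple[int, ...], List[int]] = {}
    for i in range(len(token_ids) - ngram_size + 1):
        prefix = tuple(token_ids[i:i + ngram_size - 1])
        seen.setdefault(prefix, []).append(token_ids[i + ngram_size - 1])
    current_prefix = tuple(token_ids[len(token_ids) - ngram_size + 1:])
    return seen.get(current_prefix, [])


# Example: with 2-grams, token 5 may not follow token 4 a second time.
context = [4, 5, 6, 4]
print(banned_next_tokens(context, 2))  # -> [5]

# Masking a logits row the same way the sampler does.
logits = torch.zeros(10)
logits[banned_next_tokens(context, 2)] = -float("inf")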
aphrodite/modeling/sampling_metadata.py  (+27, -13)

@@ -379,6 +379,7 @@ class SamplingTensors:
     presence_penalties: torch.Tensor
     frequency_penalties: torch.Tensor
     repetition_penalties: torch.Tensor
+    no_repeat_ngram_sizes: torch.Tensor
     tfss: torch.Tensor
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
@@ -409,7 +410,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -429,6 +430,7 @@ class SamplingTensors:
         presence_penalties: List[float] = []
         frequency_penalties: List[float] = []
         repetition_penalties: List[float] = []
+        no_repeat_ngram_sizes: List[int] = []
         tfss: List[float] = []
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
@@ -446,6 +448,7 @@ class SamplingTensors:
         dry_sequence_breaker_ids: List[List[int]] = []
 
         do_penalties = False
+        do_no_repeat_ngrams = False
         do_temperatures = False
         do_top_p_top_k = False
         do_top_as = False
@@ -493,6 +496,7 @@ class SamplingTensors:
             do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
                              abs(params.frequency_penalty) >= _SAMPLING_EPS or
                              params.repetition_penalty > 1.0)
+            do_no_repeat_ngrams |= params.no_repeat_ngram_size > 0
             do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
             do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
             do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
@@ -529,6 +533,7 @@ class SamplingTensors:
             presence_penalties += [params.presence_penalty] * n_seqs
             frequency_penalties += [params.frequency_penalty] * n_seqs
             repetition_penalties += [params.repetition_penalty] * n_seqs
+            no_repeat_ngram_sizes += [params.no_repeat_ngram_size] * n_seqs
             tfss += [params.tfs] * n_seqs
             eta_cutoffs += [params.eta_cutoff] * n_seqs
             epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
@@ -566,7 +571,7 @@ class SamplingTensors:
                     sampling_seeds.append(seq_seeds)
                 sample_indices.extend(seq_group.sample_indices)
 
-        if do_penalties or do_dry:
+        if do_penalties or do_dry or do_no_repeat_ngrams:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
                 if (seq_group.is_prompt
@@ -588,16 +593,16 @@ class SamplingTensors:
             temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
             temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
-            tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, xtc_thresholds, xtc_probabilities, nsigmas,
-            dry_multipliers, dry_bases, dry_allowed_lengths,
-            dry_sequence_breaker_ids, sampling_seeds, sample_indices,
-            prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate,
-            device, dtype)
-        return (sampling_tensors, do_penalties, do_temperatures,
-                do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
-                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
-                do_nsigmas, do_dry, do_temp_last)
+            no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
+            typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
+            xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
+            dry_allowed_lengths, dry_sequence_breaker_ids, sampling_seeds,
+            sample_indices, prompt_tokens, output_tokens, vocab_size,
+            extra_seeds_to_generate, device, dtype)
+        return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
+                do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
+                do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
+                do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
@@ -606,7 +611,8 @@ class SamplingTensors:
                    top_ks: List[int], top_as: List[float],
                    min_ps: List[float], presence_penalties: List[float],
                    frequency_penalties: List[float],
-                   repetition_penalties: List[float], tfss: List[float],
+                   repetition_penalties: List[float],
+                   no_repeat_ngram_sizes: List[int], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], xtc_thresholds: List[float],
@@ -708,6 +714,12 @@ class SamplingTensors:
             dtype=dtype,
             pin_memory=pin_memory,
         )
+        no_repeat_ngram_sizes_t = torch.tensor(
+            no_repeat_ngram_sizes,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
         top_ks_t = torch.tensor(
             top_ks,
             device="cpu",
@@ -819,6 +831,8 @@ class SamplingTensors:
                                                          non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
+            no_repeat_ngram_sizes=no_repeat_ngram_sizes_t.to(device=device,
+                                                             non_blocking=True),
             tfss=tfss_t.to(device=device, non_blocking=True),
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,

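The new per-sequence sizes follow the same staging pattern as the other sampling tensors: build on CPU (pinned when possible) and copy to the device with `non_blocking=True`. A minimal, hedged illustration of that pattern with made-up values:

# Hedged illustration of the CPU-staging pattern (made-up per-sequence values).
import torch

no_repeat_ngram_sizes = [0, 2, 3, 3]  # one entry per sequence
use_cuda = torch.cuda.is_available()

t_cpu = torch.tensor(
    no_repeat_ngram_sizes,
    device="cpu",
    dtype=torch.int,
    pin_memory=use_cuda,  # pinning requires a CUDA-capable build
)
t_dev = t_cpu.to(device="cuda" if use_cuda else "cpu", non_blocking=True)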
tests/samplers/test_sampler.py  (+73, -0)

@@ -732,6 +732,79 @@ def test_sampler_repetition_penalty_mixed(device: str):
     assert tokens1[1] == tokens2[0]
 
 
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_no_repeat_ngram(seed: int, device: str):
+    """Test that no-repeat-ngram sampling behaves as expected."""
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+
+    test_sequences = {
+        # Format: sequence: [tokens_that_should_be_blocked]
+        (1, 2, 3): [3],  # With ngram_size=2, should block 3 after [2]
+        (4, 5, 4, 5): [4],  # With ngram_size=2, should block 4 after [5]
+        (6, 7, 8, 6, 7): [8],  # With ngram_size=3, should block 8 after [6, 7]
+        (1, 2, 3, 4, 1, 2): [3],  # With ngram_size=4, should block 3 after [1, 2]  # noqa: E501
+    }
+
+    for input_seq, blocked_tokens in test_sequences.items():
+        for ngram_size in [2, 3, 4]:
+            sampling_params = SamplingParams(
+                temperature=1.0,
+                no_repeat_ngram_size=ngram_size,
+                seed=random.randint(0, 10000),
+            )
+
+            sampler_output = _do_sample(
+                1, 
+                fake_logits[0:1].clone(),  # Just use first row
+                sampler,
+                sampling_params,
+                device
+            )
+
+            if len(input_seq) >= ngram_size:
+                # check if blocked tokens have -inf logits
+                for token in blocked_tokens:
+                    assert sampler_output[0].samples[0].output_token != token, \
+                        f"Token {token} should have been blocked by {ngram_size}-gram repetition prevention"  # noqa: E501
+
+        # disabled
+        sampling_params = SamplingParams(
+            temperature=1.0,
+            no_repeat_ngram_size=0,
+            seed=random.randint(0, 10000),
+        )
+
+        sampler_output = _do_sample(
+            1,
+            fake_logits[0:1].clone(),
+            sampler,
+            sampling_params,
+            device
+        )
+
+        output_token = sampler_output[0].samples[0].output_token
+        assert output_token is not None, "Should produce output token with ngram_size=0"  # noqa: E501
+
+    # determinism
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        no_repeat_ngram_size=3,
+        seed=random.randint(0, 10000),
+    )
+
+    first_output = _do_sample(batch_size, fake_logits.clone(), sampler,
+                             sampling_params, device)
+    second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
+                              sampling_params, device)
+
+    assert first_output == second_output, \
+        "No-repeat-ngram sampling is not deterministic with same seed"
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_nsigma(seed: int, device: str):