Browse Source

sampler: fix dry concurrency issue (#852)

AlpinDale 3 months ago
parent
commit
72c505ad84
1 changed files with 6 additions and 4 deletions
  1. 6 4
      aphrodite/modeling/layers/sampler.py

+ 6 - 4
aphrodite/modeling/layers/sampler.py

@@ -625,7 +625,7 @@ def _apply_dry(
     sequence_breakers_ids: torch.Tensor
 ) -> torch.Tensor:
     """
-    Apply Exclude Don't Repeat Yourself (DRY) sampling to the logits.
+    Apply Don't Repeat Yourself (DRY) sampling to the logits.
 
     Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
     """
@@ -635,6 +635,8 @@ def _apply_dry(
 
     # we need to apply dry to both input and output tokens
     input_ids = torch.cat((input_token_ids, output_token_ids), dim=1)
+    vocab_size = logits.size(-1)
+    
     # Process each sequence in the batch
     for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)):
         multiplier = multipliers[i].item()
@@ -661,8 +663,8 @@ def _apply_dry(
             # Get the token that followed this match in the input
             next_token = input_ids_row[idx + 1].item()
 
-            # Skip if next token is a sequence breaker
-            if next_token in sequence_breakers_ids:
+            # Skip if next token is a sequence breaker or out of vocab range
+            if next_token in sequence_breakers_ids or next_token >= vocab_size:
                 continue
 
             # We found last_token matches at this index, so match length starts
@@ -700,7 +702,7 @@ def _apply_dry(
         base = bases[i]
 
         for token, match_length in match_lengths.items():
-            if match_length >= allowed_length:
+            if match_length >= allowed_length and token < vocab_size:
                 penalty = multiplier * (base ** (match_length - allowed_length))
                 logits_row[token] -= penalty