@@ -625,7 +625,7 @@ def _apply_dry(
     sequence_breakers_ids: torch.Tensor
 ) -> torch.Tensor:
     """
-    Apply Exclude Don't Repeat Yourself (DRY) sampling to the logits.
+    Apply Don't Repeat Yourself (DRY) sampling to the logits.

     Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
     """
@@ -635,6 +635,8 @@ def _apply_dry(

     # we need to apply dry to both input and output tokens
     input_ids = torch.cat((input_token_ids, output_token_ids), dim=1)
+    vocab_size = logits.size(-1)
+
     # Process each sequence in the batch
     for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)):
         multiplier = multipliers[i].item()
@@ -661,8 +663,8 @@ def _apply_dry(
             # Get the token that followed this match in the input
             next_token = input_ids_row[idx + 1].item()

-            # Skip if next token is a sequence breaker
-            if next_token in sequence_breakers_ids:
+            # Skip if next token is a sequence breaker or out of vocab range
+            if next_token in sequence_breakers_ids or next_token >= vocab_size:
                 continue

             # We found last_token matches at this index, so match length starts
@@ -700,7 +702,7 @@ def _apply_dry(
         base = bases[i]

         for token, match_length in match_lengths.items():
-            if match_length >= allowed_length:
+            if match_length >= allowed_length and token < vocab_size:
                 penalty = multiplier * (base ** (match_length - allowed_length))
                 logits_row[token] -= penalty
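
For context on what the guarded code is doing: DRY scans the context for earlier occurrences of the current token, measures how long the suffix ending at each occurrence matches the current suffix, and subtracts multiplier * base ** (match_length - allowed_length) from the logit of whichever token followed that match. Below is a minimal single-sequence sketch of that logic with this PR's vocab-range guards folded in. The function name, the plain-list interface, and the default hyperparameter values are illustrative assumptions for the sketch, not the batched _apply_dry from the patch.

import torch

def apply_dry_sketch(
    logits: torch.Tensor,          # (vocab_size,) logits for one sequence
    token_ids: list[int],          # prompt + generated tokens so far
    sequence_breakers: set[int],   # tokens that reset matching
    multiplier: float = 0.8,       # illustrative defaults, not from this PR
    base: float = 1.75,
    allowed_length: int = 2,
) -> torch.Tensor:
    vocab_size = logits.size(-1)
    last_token = token_ids[-1]
    match_lengths: dict[int, int] = {}

    # Find earlier occurrences of the last token and measure how far the
    # context matches the current suffix, walking backwards from each one.
    for idx in range(len(token_ids) - 1):
        if token_ids[idx] != last_token:
            continue
        next_token = token_ids[idx + 1]
        # Mirror the PR's guards: skip sequence breakers and token ids that
        # fall outside the vocab range of the logits tensor.
        if next_token in sequence_breakers or next_token >= vocab_size:
            continue
        match_length = 1
        while (idx - match_length >= 0
               and token_ids[idx - match_length] == token_ids[-1 - match_length]
               and token_ids[idx - match_length] not in sequence_breakers):
            match_length += 1
        match_lengths[next_token] = max(match_lengths.get(next_token, 0),
                                        match_length)

    # Penalize tokens that would extend a long enough repetition; the
    # token < vocab_size check matches the last hunk above.
    for token, match_length in match_lengths.items():
        if match_length >= allowed_length and token < vocab_size:
            logits[token] -= multiplier * (base ** (match_length - allowed_length))
    return logits

For example, with token_ids = [1, 2, 3, 4, 2, 3] the suffix "2 3" already occurred once, followed by 4, so the logit of token 4 is reduced by 0.8 * 1.75 ** 0 = 0.8; each additional repeated token grows the penalty exponentially, which is what discourages verbatim loops.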