|
@@ -330,6 +330,8 @@ def _apply_temperatures(
|
|
|
def _apply_token_bans(logits: torch.Tensor,
|
|
|
banned_tokens: List[List[int]]) -> torch.Tensor:
|
|
|
for i, banned_token_ids in enumerate(banned_tokens):
|
|
|
+ if i >= logits.size(0):
|
|
|
+ break
|
|
|
if not banned_token_ids:
|
|
|
continue
|
|
|
logits[i, banned_token_ids] = -float("inf")
|