Ver código fonte

fix: crash in token bans (#764)

* Update sampler.py

Fixes a crash from token bans.

* Update sampler.py

Removed hashed out note that was causing linting failure.
Pyroserenus 5 meses atrás
pai
commit
99fc6f4697
1 arquivos alterados com 2 adições e 0 exclusões
  1. 2 0
      aphrodite/modeling/layers/sampler.py

+ 2 - 0
aphrodite/modeling/layers/sampler.py

@@ -330,6 +330,8 @@ def _apply_temperatures(
 def _apply_token_bans(logits: torch.Tensor,
                       banned_tokens: List[List[int]]) -> torch.Tensor:
     for i, banned_token_ids in enumerate(banned_tokens):
+        if i >= logits.size(0):
+            break
         if not banned_token_ids:
             continue
         logits[i, banned_token_ids] = -float("inf")