Selaa lähdekoodia

fix: crash in token bans (#764)

* Update sampler.py

Fixes a crash from token bans.

* Update sampler.py

Removed hashed out note that was causing linting failure.
Pyroserenus 5 kuukautta sitten
vanhempi
commit
99fc6f4697
1 muutettua tiedostoa jossa 2 lisäystä ja 0 poistoa
  1. 2 0
      aphrodite/modeling/layers/sampler.py

+ 2 - 0
aphrodite/modeling/layers/sampler.py

@@ -330,6 +330,8 @@ def _apply_temperatures(
 def _apply_token_bans(logits: torch.Tensor,
                       banned_tokens: List[List[int]]) -> torch.Tensor:
     for i, banned_token_ids in enumerate(banned_tokens):
+        if i >= logits.size(0):
+            break
         if not banned_token_ids:
             continue
         logits[i, banned_token_ids] = -float("inf")