|
@@ -258,8 +258,8 @@ def _apply_alphabet_soup(
|
|
|
logits_sort[mask] = -float("inf")
|
|
|
|
|
|
# Apply top-k.
|
|
|
- for i,k in enumerate(k):
|
|
|
- logits_sort[i, k:] = -float("inf")
|
|
|
+ for i,topk in enumerate(k):
|
|
|
+ logits_sort[i, topk:] = -float("inf")
|
|
|
|
|
|
# Re-sort the probabilities.
|
|
|
src = torch.arange(logits_idx.shape[-1],
|