logits_processors.py

from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer

from aphrodite.common.sampling_params import LogitsProcessorFunc


class AllowedTokenIdsLogitsProcessor:
    """Logits processor for constraining generated tokens to a
    specific set of token ids."""

    def __init__(self, allowed_ids: Iterable[int]):
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        if self.mask is None:
            self.mask = torch.ones((logits.shape[-1], ),
                                   dtype=torch.bool,
                                   device=logits.device)
            self.mask[self.allowed_ids] = False
            self.allowed_ids = None
        logits.masked_fill_(self.mask, float("-inf"))
        return logits
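
# A minimal usage sketch for the class above (illustrative only; the toy
# 8-token vocab and the allowed ids {2, 5} are assumptions, not values used
# anywhere else in this module):
#
#     proc = AllowedTokenIdsLogitsProcessor([2, 5])
#     out = proc([], torch.zeros(8))
#     # every entry of `out` is now -inf except indices 2 and 5, so sampling
#     # can only ever pick one of the allowed token ids.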


@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessorFunc:
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")
    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
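
# Caching note (illustrative): because the factory above is wrapped in
# lru_cache and keyed on a hashable frozenset, identical allow-lists reuse
# a single processor instance (and hence a single lazily-built mask), e.g.
#
#     a = _get_allowed_token_ids_logits_processor(frozenset({1, 2}), 10)
#     b = _get_allowed_token_ids_logits_processor(frozenset({1, 2}), 10)
#     assert a is b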


def logit_bias_logits_processor(logit_bias: Dict[int, float],
                                token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits
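
# Illustrative sketch of how the function above is bound with
# functools.partial further down (the token ids and bias values here are
# made up for the example):
#
#     proc = partial(logit_bias_logits_processor, {42: 5.0, 7: -5.0})
#     logits = proc([], torch.zeros(100))
#     # logits[42] is now 5.0, logits[7] is -5.0, everything else unchanged.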


def get_logits_processors(
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
        allowed_token_ids: Optional[List[int]],
        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessorFunc]:
    logits_processors = []
    if logit_bias:
        try:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            clamped_logit_bias: Dict[int, float] = {
                int(token_id): min(100.0, max(-100.0, bias))
                for token_id, bias in logit_bias.items()
            }
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Check if token_id is within the vocab size
        for token_id, bias in clamped_logit_bias.items():
            if token_id < 0 or token_id >= tokenizer.vocab_size:
                raise ValueError("token_id in logit_bias contains "
                                 "out-of-vocab token id")

        logits_processors.append(
            partial(logit_bias_logits_processor, clamped_logit_bias))

    if allowed_token_ids is not None:
        logits_processors.append(
            _get_allowed_token_ids_logits_processor(
                frozenset(allowed_token_ids), tokenizer.vocab_size))

    return logits_processors
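

# A hedged, self-contained demo of the public entry point above. It is a
# sketch, not part of the module's API: the "gpt2" checkpoint is an
# illustrative choice, and running it requires the tokenizer files to be
# downloadable via transformers.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    demo_processors = get_logits_processors(
        logit_bias={"50256": -100.0},  # string keys are accepted and coerced
        allowed_token_ids=None,
        tokenizer=demo_tokenizer,
    )
    demo_logits = torch.zeros(demo_tokenizer.vocab_size)
    for processor in demo_processors:
        demo_logits = processor([], demo_logits)
    print(demo_logits[50256])  # expected: tensor(-100.)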