logits_processors.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from functools import lru_cache, partial
  2. from typing import Dict, FrozenSet, Iterable, List, Optional, Union
  3. import torch
  4. from transformers import PreTrainedTokenizer
  5. from aphrodite.common.sampling_params import LogitsProcessorFunc
  6. class AllowedTokenIdsLogitsProcessor:
  7. """Logits processor for constraining generated tokens to a
  8. specific set of token ids."""
  9. def __init__(self, allowed_ids: Iterable[int]):
  10. self.allowed_ids: Optional[List[int]] = list(allowed_ids)
  11. self.mask: Optional[torch.Tensor] = None
  12. def __call__(self, token_ids: List[int],
  13. logits: torch.Tensor) -> torch.Tensor:
  14. if self.mask is None:
  15. self.mask = torch.ones((logits.shape[-1], ),
  16. dtype=torch.bool,
  17. device=logits.device)
  18. self.mask[self.allowed_ids] = False
  19. self.allowed_ids = None
  20. logits.masked_fill_(self.mask, float("-inf"))
  21. return logits
  22. @lru_cache(maxsize=32)
  23. def _get_allowed_token_ids_logits_processor(
  24. allowed_token_ids: FrozenSet[int],
  25. vocab_size: int,
  26. ) -> LogitsProcessorFunc:
  27. if not allowed_token_ids:
  28. raise ValueError("Empty allowed_token_ids provided")
  29. if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
  30. raise ValueError("allowed_token_ids contains "
  31. "out-of-vocab token id")
  32. return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
  33. def logit_bias_logits_processor(
  34. logit_bias: Dict[int, float],
  35. token_ids: List[int],
  36. logits: torch.Tensor,
  37. ) -> torch.Tensor:
  38. for token_id, bias in logit_bias.items():
  39. logits[token_id] += bias
  40. return logits
  41. def get_logits_processors(
  42. logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
  43. allowed_token_ids: Optional[List[int]],
  44. tokenizer: PreTrainedTokenizer) -> List[LogitsProcessorFunc]:
  45. logits_processors = []
  46. if logit_bias:
  47. try:
  48. # Convert token_id to integer
  49. # Clamp the bias between -100 and 100 per OpenAI API spec
  50. clamped_logit_bias: Dict[int, float] = {
  51. int(token_id): min(100.0, max(-100.0, bias))
  52. for token_id, bias in logit_bias.items()
  53. }
  54. except ValueError as exc:
  55. raise ValueError(
  56. "Found token_id in logit_bias that is not "
  57. "an integer or string representing an integer") from exc
  58. # Check if token_id is within the vocab size
  59. for token_id, bias in clamped_logit_bias.items():
  60. if token_id < 0 or token_id >= tokenizer.vocab_size:
  61. raise ValueError("token_id in logit_bias contains "
  62. "out-of-vocab token id")
  63. logits_processors.append(
  64. partial(logit_bias_logits_processor, clamped_logit_bias))
  65. if allowed_token_ids is not None:
  66. logits_processors.append(
  67. _get_allowed_token_ids_logits_processor(
  68. frozenset(allowed_token_ids), tokenizer.vocab_size))
  69. return logits_processors