# logits_processor.py (2.3 KB)
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, List
  3. import torch
  4. class LogitsProcessor(ABC):
  5. @abstractmethod
  6. def __call__(self, output_tokens: List[int],
  7. logits: torch.Tensor) -> torch.Tensor:
  8. """Logits are edited in-place"""
  9. pass
  10. @abstractmethod
  11. def batched(self, logits: torch.Tensor,
  12. output_tokens: List[List[int]]) -> None:
  13. """Logits are edited in-place"""
  14. pass
  15. class BiasLogitsProcessor(LogitsProcessor):
  16. """Apply an additive bias to specific token logits.
  17. Args:
  18. biases: Dict of bias values. Each key corresponds to the the token id.
  19. """
  20. def __init__(self, biases: Dict[int, float]):
  21. assert biases
  22. self.biases = biases
  23. self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
  24. self.values = torch.tensor(list(self.biases.values()),
  25. dtype=torch.float)
  26. def __call__(self, output_tokens: List[int],
  27. logits: torch.Tensor) -> torch.Tensor:
  28. values = self.values.to(logits.device)
  29. keys = self.keys.to(logits.device)
  30. logits[keys] += values
  31. return logits
  32. def batched(self, logits: torch.Tensor,
  33. output_tokens: List[List[int]]) -> None:
  34. values = self.values.to(logits.device)
  35. keys = self.keys.to(logits.device)
  36. logits[:, keys] += values
  37. class BanEOSUntil(LogitsProcessor):
  38. """Bans the EOS token until a certain condition is met.
  39. In this case, 'number of output tokens'.
  40. With this condition, both 'min_tokens' and 'ignore_eos'
  41. parameters can be handled gracefully."""
  42. def __init__(self, min_tokens: int, eos_token_id: int):
  43. self._min_tokens = min_tokens
  44. self._eos_token_id = eos_token_id
  45. def __call__(self, output_tokens: List[int],
  46. logits: torch.Tensor) -> torch.Tensor:
  47. if len(output_tokens) < self._min_tokens:
  48. logits[self._eos_token_id] = -float("inf")
  49. return logits
  50. def batched(self, logits: torch.Tensor,
  51. output_tokens: List[List[int]]) -> None:
  52. terminate_mask = torch.tensor(
  53. [len(toks) < self._min_tokens for toks in output_tokens],
  54. device=logits.device)
  55. logits[terminate_mask, self._eos_token_id] = -float("inf")