# logits_processor.py — custom logits processors (additive bias, EOS banning).
from abc import ABC, abstractmethod
from typing import Dict, List

import torch
  4. class LogitsProcessor(ABC):
  5. @abstractmethod
  6. def __call__(self, logits: torch.Tensor,
  7. output_tokens: List[List[int]]) -> None:
  8. """Logits are edited in-place"""
  9. pass
  10. class BiasLogitsProcessor(LogitsProcessor):
  11. """This is to enable logit_bias in the OpenAI server,
  12. an additive bias on the original logit values.
  13. Args:
  14. biases: Dict of bias values. Each key corresponds to the the token id.
  15. """
  16. def __init__(self, biases: Dict[int, float]):
  17. super().__init__()
  18. self.biases = biases
  19. if not biases:
  20. return
  21. self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
  22. self.values = torch.tensor(list(self.biases.values()),
  23. dtype=torch.float)
  24. def __call__(self, logits: torch.Tensor,
  25. output_tokens: List[List[int]]) -> None:
  26. if not self.biases:
  27. return
  28. values = self.values.to(logits.device)
  29. keys = self.keys.to(logits.device)
  30. logits[0, keys] += values
  31. class BanEOSUntil(LogitsProcessor):
  32. """Bans the EOS token until a certain condition is met.
  33. In this case, 'number of output tokens'.
  34. With this condition, both 'min_tokens' and 'ignore_eos'
  35. parameters can be handled gracefully."""
  36. def __init__(self, min_tokens: int, eos_token_id: int):
  37. super().__init__()
  38. self._min_tokens = min_tokens
  39. self._eos_token_id = eos_token_id
  40. def __call__(self, logits: torch.Tensor,
  41. output_tokens: List[List[int]]) -> None:
  42. for i in range(len(output_tokens)):
  43. if len(output_tokens[i]) < self._min_tokens:
  44. logits[i][self._eos_token_id] = -float("inf")