# logits_processor.py
  1. from abc import ABC, abstractmethod
  2. import torch
  3. from typing import Dict, List
  4. class LogitsProcessor(ABC):
  5. @abstractmethod
  6. def __call__(self, logits: torch.Tensor,
  7. output_tokens: List[List[int]]) -> None:
  8. """Logits are edited in-place"""
  9. pass
  10. class BiasLogitsProcessor(LogitsProcessor):
  11. """This is to enable logit_bias in the OpenAI server.
  12. biases is a dict where each value is -100 to 100
  13. according to the OpenAI API docs.
  14. Args:
  15. biases: Dict of values from -100 to 100 to scale the
  16. probability of a token being generated.
  17. Each key of the dict corresponds to the the token id.
  18. """
  19. def __init__(self, biases: Dict[int, float]):
  20. super().__init__()
  21. self.biases = biases
  22. if not biases:
  23. return
  24. self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
  25. self.values = torch.tensor(list(self.biases.values()),
  26. dtype=torch.long)
  27. def __call__(self, logits, output_tokens):
  28. if not self.biases:
  29. return
  30. values = self.values.to(logits.device)
  31. keys = self.keys.to(logits.device)
  32. update_factors = torch.where(values >= 0, 1 + (values / 100),
  33. 1 / (1 - (values / 100)))
  34. logits[0, keys] *= update_factors
  35. class BanEOSUntil(LogitsProcessor):
  36. """Bans the EOS token until a certain condition is met.
  37. In this case, 'number of output tokens'.
  38. With this condition, both 'min_tokens' and 'ignore_eos'
  39. parameters can be handled gracefully."""
  40. def __init__(self, min_tokens: int, eos_token_id: int):
  41. super().__init__()
  42. self._min_tokens = min_tokens
  43. self._eos_token_id = eos_token_id
  44. def __call__(self, logits, output_tokens):
  45. for i in range(len(output_tokens)):
  46. if len(output_tokens[i]) < self._min_tokens:
  47. logits[i][self._eos_token_id] = -float("inf")