# logits_processor.py
  1. from abc import ABC, abstractmethod
  2. import torch
  3. from typing import Dict, List
class LogitsProcessor(ABC):
    """Interface for callables that edit a logits tensor in place."""

    @abstractmethod
    def __call__(self, logits: torch.Tensor,
                 output_tokens: List[List[int]]) -> None:
        """Logits are edited in-place.

        Args:
            logits: Tensor of next-token logits, mutated in place.
            output_tokens: Tokens generated so far, one list of token
                ids per sequence.
        """
        pass
  10. class BiasLogitsProcessor(LogitsProcessor):
  11. """This is to enable logit_bias in the OpenAI server.
  12. biases is a dict where each value is -100 to 100
  13. according to the OpenAI API docs.
  14. Args:
  15. biases: Dict of values from -100 to 100 to scale the
  16. probability of a token being generated.
  17. Each key of the dict corresponds to the the token id.
  18. """
  19. def __init__(self, biases: Dict[int, float]):
  20. self.biases = biases
  21. if not biases:
  22. return
  23. self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
  24. self.values = torch.tensor(list(self.biases.values()),
  25. dtype=torch.long)
  26. def __call__(self, logits, output_tokens):
  27. if not self.biases:
  28. return
  29. values = self.values.to(logits.device)
  30. keys = self.keys.to(logits.device)
  31. update_factors = torch.where(values >= 0, 1 + (values / 100),
  32. 1 / (1 - (values / 100)))
  33. logits[0, keys] *= update_factors
  34. class BanEOSUntil(LogitsProcessor):
  35. """Bans the EOS token until a certain condition is met.
  36. In this case, 'number of output tokens'.
  37. With this condition, both 'min_tokens' and 'ignore_eos'
  38. parameters can be handled gracefully."""
  39. def __init__(self, min_tokens: int, eos_token_id: int):
  40. self._min_tokens = min_tokens
  41. self._eos_token_id = eos_token_id
  42. def __call__(self, logits, output_tokens):
  43. for i in range(len(output_tokens)):
  44. if len(output_tokens[i]) < self._min_tokens:
  45. logits[i][self._eos_token_id] = -float("inf")