logits.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from abc import ABC, abstractmethod
  2. import torch
  3. from typing import Dict
  4. class LogitsProcessor(ABC):
  5. @abstractmethod
  6. def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
  7. pass
  8. class BiasLogitsProcessor(LogitsProcessor):
  9. """This is to enable logit_bias in the OpenAI server.
  10. biases is a dict where each value is -100 to 100
  11. according to the OpenAI API docs.
  12. Args:
  13. biases: Dict ov values from -100 to 100 to scale the
  14. probability of a token being generated.
  15. Each key of the dict coresponds to the the token id.
  16. """
  17. def __init__(self, biases: Dict[int, float]):
  18. self.biases = biases
  19. if not biases:
  20. return
  21. self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
  22. self.values = torch.tensor(list(self.biases.values()),
  23. dtype=torch.long)
  24. def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
  25. if not self.biases:
  26. return logits
  27. values = self.values.to(logits.device)
  28. keys = self.keys.to(logits.device)
  29. update_factors = torch.where(values >= 0, 1 + (values / 100),
  30. 1 / (1 - (values / 100)))
  31. logits[0, keys] *= update_factors
  32. return logits