interfaces.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. import torch
  4. from aphrodite.common.sequence import ExecuteModelRequest
  5. @dataclass
  6. class SpeculativeProposals:
  7. """Datastructure used to represent proposal tokens from some proposer. It
  8. also tracks how many speculative tokens each sequence has.
  9. """
  10. # Speculative proposal tokens.
  11. proposal_token_ids: torch.Tensor
  12. # Probabilities of the proposal tokens according to the proposer.
  13. proposal_probs: torch.Tensor
  14. # The valid length of each proposal; can be zero.
  15. proposal_lens: torch.Tensor
  16. def __repr__(self):
  17. return (f"SpeculativeProposals("
  18. f"proposal_token_ids={self.proposal_token_ids}, "
  19. f"proposal_probs={self.proposal_probs.shape}, "
  20. f"proposal_lens={self.proposal_lens})")
  21. @dataclass
  22. class SpeculativeScores:
  23. """Datastructure used to represent the scores of speculative tokens
  24. according to the scoring model.
  25. """
  26. # Probabilities of the speculative tokens according to the scoring model.
  27. probs: torch.Tensor
  28. # Log-probabilities of the speculative tokens according to the scoring
  29. # model. These values can be used to generate Logprob objects that are
  30. # returned to the user.
  31. logprobs: torch.Tensor
  32. # Token ids sampled from the scoring model. Used for speculative bonus
  33. # tokens and also non-speculative normal decoding.
  34. token_ids: torch.Tensor
  35. def __repr__(self):
  36. return (f"SpeculativeScores("
  37. f"probs={self.probs.shape}, "
  38. f"token_ids={self.token_ids.shape})")
  39. class SpeculativeProposer(ABC):
  40. @abstractmethod
  41. def get_proposals(
  42. self,
  43. execute_model_req: ExecuteModelRequest,
  44. ) -> SpeculativeProposals:
  45. raise NotImplementedError
  46. class SpeculativeScorer(ABC):
  47. @abstractmethod
  48. def score_proposals(
  49. self,
  50. execute_model_req: ExecuteModelRequest,
  51. proposals: SpeculativeProposals,
  52. ) -> SpeculativeScores:
  53. raise NotImplementedError