interfaces.py 2.4 KB

Lines 1–83
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. from typing import Optional, Set
  4. import torch
  5. from aphrodite.common.sequence import ExecuteModelRequest
  6. @dataclass
  7. class SpeculativeProposals:
  8. """Datastructure used to represent proposal tokens from some proposer. It
  9. also tracks how many speculative tokens each sequence has.
  10. """
  11. # Speculative proposal tokens.
  12. proposal_token_ids: torch.Tensor
  13. # Probabilities of the proposal tokens according to the proposer.
  14. proposal_probs: torch.Tensor
  15. # The valid length of each proposal; can be zero.
  16. proposal_lens: torch.Tensor
  17. # A flag to mark that there's no available proposals
  18. no_proposals: bool = False
  19. def __repr__(self):
  20. return (f"SpeculativeProposals("
  21. f"proposal_token_ids={self.proposal_token_ids}, "
  22. f"proposal_probs={self.proposal_probs.shape}, "
  23. f"proposal_lens={self.proposal_lens})")
  24. @dataclass
  25. class SpeculativeScores:
  26. """Datastructure used to represent the scores of speculative tokens
  27. according to the scoring model.
  28. """
  29. # Probabilities of the speculative tokens according to the scoring model.
  30. probs: torch.Tensor
  31. # Log-probabilities of the speculative tokens according to the scoring
  32. # model. These values can be used to generate Logprob objects that are
  33. # returned to the user.
  34. logprobs: torch.Tensor
  35. # Token ids sampled from the scoring model. Used for speculative bonus
  36. # tokens and also non-speculative normal decoding.
  37. token_ids: torch.Tensor
  38. # Optional last hidden states from the scoring model.
  39. hidden_states: Optional[torch.Tensor] = None
  40. def __repr__(self):
  41. return (f"SpeculativeScores("
  42. f"probs={self.probs.shape}, "
  43. f"token_ids={self.token_ids.shape})")
  44. class SpeculativeProposer(ABC):
  45. @abstractmethod
  46. def get_spec_proposals(
  47. self,
  48. execute_model_req: ExecuteModelRequest,
  49. # If set, this contains all sequence IDs that were assigned
  50. # bonus tokens in their last forward pass.
  51. seq_ids_with_bonus_token_in_last_step: Set[int],
  52. ) -> SpeculativeProposals:
  53. raise NotImplementedError
  54. class SpeculativeScorer(ABC):
  55. @abstractmethod
  56. def score_proposals(
  57. self,
  58. execute_model_req: ExecuteModelRequest,
  59. proposals: SpeculativeProposals,
  60. ) -> SpeculativeScores:
  61. raise NotImplementedError