interfaces.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. from typing import Dict, List, Optional
  4. import torch
  5. from aphrodite.common.sequence import SequenceGroupMetadata
  6. @dataclass
  7. class SpeculativeProposals:
  8. """Datastructure used to represent proposal tokens from some proposer. It
  9. also tracks how many speculative tokens each sequence has.
  10. """
  11. # Speculative proposal tokens.
  12. proposal_token_ids: torch.Tensor
  13. # Probabilities of the proposal tokens according to the proposer.
  14. proposal_probs: torch.Tensor
  15. # The valid length of each proposal; can be zero.
  16. proposal_lens: torch.Tensor
  17. def __repr__(self):
  18. return (f"SpeculativeProposals("
  19. f"proposal_token_ids={self.proposal_token_ids}, "
  20. f"proposal_probs={self.proposal_probs.shape}, "
  21. f"proposal_lens={self.proposal_lens})")
  22. @dataclass
  23. class SpeculativeScores:
  24. """Datastructure used to represent the scores of speculative tokens
  25. according to the scoring model.
  26. """
  27. # Probabilities of the speculative tokens according to the scoring model.
  28. probs: torch.Tensor
  29. # Token ids sampled from the scoring model. Used for speculative bonus
  30. # tokens and also non-speculative normal decoding.
  31. token_ids: torch.Tensor
  32. def __repr__(self):
  33. return (f"SpeculativeScores("
  34. f"probs={self.probs.shape}, "
  35. f"token_ids={self.token_ids.shape})")
  36. class SpeculativeProposer(ABC):
  37. @abstractmethod
  38. def get_proposals(
  39. self,
  40. seq_group_metadata_list: List[SequenceGroupMetadata],
  41. blocks_to_swap_in: Dict[int, int],
  42. blocks_to_swap_out: Dict[int, int],
  43. blocks_to_copy: Dict[int, List[int]],
  44. max_proposal_len: int,
  45. ) -> SpeculativeProposals:
  46. raise NotImplementedError
  47. class SpeculativeScorer(ABC):
  48. @abstractmethod
  49. def score_proposals(
  50. self,
  51. seq_group_metadata_list: List[SequenceGroupMetadata],
  52. blocks_to_swap_in: Optional[Dict[int, int]],
  53. blocks_to_swap_out: Optional[Dict[int, int]],
  54. blocks_to_copy: Optional[Dict[int, List[int]]],
  55. k: int,
  56. proposals: SpeculativeProposals,
  57. ) -> SpeculativeScores:
  58. raise NotImplementedError