12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- import torch
- from aphrodite.common.sequence import ExecuteModelRequest
@dataclass
class SpeculativeProposals:
    """Container for the speculative tokens emitted by a proposer.

    Besides the proposed token ids and their probabilities, it records how
    many speculative tokens belong to each sequence.
    """

    # Proposed speculative token ids.
    # NOTE(review): presumably shaped (batch_size, k) -- confirm with proposer.
    proposal_token_ids: torch.Tensor

    # Proposer probabilities for the proposed tokens. Only its shape is shown
    # by __repr__.
    # NOTE(review): presumably (batch_size, k, vocab_size) -- confirm.
    proposal_probs: torch.Tensor

    # Number of speculative tokens each sequence has.
    proposal_lens: torch.Tensor

    def __repr__(self) -> str:
        # Token ids and lengths are printed in full; for the probability
        # tensor only the shape is printed.
        fields = (
            f"proposal_token_ids={self.proposal_token_ids}",
            f"proposal_probs={self.proposal_probs.shape}",
            f"proposal_lens={self.proposal_lens}",
        )
        return "SpeculativeProposals(" + ", ".join(fields) + ")"
@dataclass
class SpeculativeScores:
    """Datastructure used to represent the scores of speculative tokens
    according to the scoring model.
    """

    # Probabilities of the speculative tokens according to the scoring model.
    # NOTE(review): presumably (batch_size, k, vocab_size) -- confirm.
    probs: torch.Tensor

    # Log-probabilities of the speculative tokens according to the scoring
    # model.
    # NOTE(review): presumably the same shape as ``probs`` -- confirm.
    logprobs: torch.Tensor

    # Token ids associated with these scores (presumably sampled by the
    # scoring model -- TODO confirm against the scorer implementation).
    token_ids: torch.Tensor

    def __repr__(self) -> str:
        # Only tensor shapes are shown, which keeps the output compact for
        # large score tensors. Every field is listed so that none (previously
        # ``logprobs``) is hidden when debugging.
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"logprobs={self.logprobs.shape}, "
                f"token_ids={self.token_ids.shape})")
class SpeculativeProposer(ABC):
    """Abstract interface for components that produce speculative proposals.

    Because ``get_spec_proposals`` is abstract, a subclass cannot be
    instantiated until it overrides that method.
    """

    @abstractmethod
    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Build speculative proposals for ``execute_model_req``.

        Args:
            execute_model_req: The model execution request to propose for.

        Returns:
            A ``SpeculativeProposals`` instance.

        Raises:
            NotImplementedError: If this base implementation is called
                directly (e.g. via ``super()``).
        """
        raise NotImplementedError()
class SpeculativeScorer(ABC):
    """Abstract interface for components that score speculative proposals.

    Because ``score_proposals`` is abstract, a subclass cannot be
    instantiated until it overrides that method.
    """

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score ``proposals`` for the given ``execute_model_req``.

        Args:
            execute_model_req: The model execution request being served.
            proposals: Speculative proposals to be scored.

        Returns:
            A ``SpeculativeScores`` instance.

        Raises:
            NotImplementedError: If this base implementation is called
                directly (e.g. via ``super()``).
        """
        raise NotImplementedError()
|