1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- from typing import List, Tuple, Optional, Dict
- from dataclasses import dataclass
- from abc import ABC, abstractmethod
- import torch
- from aphrodite.common.sequence import SequenceGroupMetadata
@dataclass
class SpeculativeProposals:
    """Datastructure used to represent proposal tokens from some proposer. It
    also tracks how many speculative tokens each sequence has.

    The custom ``__repr__`` reports only the shape of each tensor field
    rather than the tensor contents.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    def __repr__(self):
        """Return a compact summary listing the shape of every field."""
        return (f"SpeculativeProposals("
                f"proposal_token_ids={self.proposal_token_ids.shape}, "
                f"proposal_probs={self.proposal_probs.shape}, "
                f"proposal_lens={self.proposal_lens.shape})")
@dataclass
class SpeculativeScores:
    """Datastructure used to represent the scores of speculative tokens
    according to the scoring model.

    The custom ``__repr__`` reports only the shape of each tensor field
    rather than the tensor contents.
    """

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    def __repr__(self):
        """Return a compact summary listing the shape of every field."""
        return (f"SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"token_ids={self.token_ids.shape})")
class SpeculativeProposer(ABC):
    """Abstract interface for objects that produce speculative proposals.

    Subclasses must override ``get_proposals``; because the method is
    abstract, this class cannot be instantiated directly.
    """

    @abstractmethod
    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculative proposals for the given sequence groups.

        Args:
            seq_group_metadata_list: Metadata for the sequence groups to
                generate proposals for.
            blocks_to_swap_in: Integer-to-integer block mapping.
                NOTE(review): presumably cache-block swap instructions
                forwarded to the proposer's worker -- confirm against the
                concrete implementations.
            blocks_to_swap_out: Integer-to-integer block mapping (same
                caveat as ``blocks_to_swap_in``).
            blocks_to_copy: Mapping from an integer to a list of integers
                (same caveat as ``blocks_to_swap_in``).
            max_proposal_len: Upper bound on the proposal length, going by
                the parameter name -- TODO confirm exact semantics.

        Returns:
            A ``SpeculativeProposals`` instance.

        Raises:
            NotImplementedError: Always, in this base declaration.
        """
        raise NotImplementedError
class SpeculativeScorer(ABC):
    """Abstract interface for objects that score speculative proposals.

    Subclasses must override ``score_proposals``; because the method is
    abstract, this class cannot be instantiated directly.
    """

    # NOTE(review): the return annotation below is a pair of tensors, while
    # this file also defines a ``SpeculativeScores`` dataclass holding
    # ``probs`` and ``token_ids``. Check which of the two the concrete
    # scorers actually return before tightening the annotation.
    @abstractmethod
    def score_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score the given proposals with the scoring model.

        Args:
            seq_group_metadata_list: Metadata for the sequence groups whose
                proposals are being scored.
            blocks_to_swap_in: Optional integer-to-integer block mapping.
                NOTE(review): presumably cache-block swap instructions
                forwarded to the scorer's worker -- confirm against the
                concrete implementations.
            blocks_to_swap_out: Optional integer-to-integer block mapping
                (same caveat as ``blocks_to_swap_in``).
            blocks_to_copy: Optional mapping from an integer to a list of
                integers (same caveat as ``blocks_to_swap_in``).
            k: Integer parameter; presumably the number of speculative
                tokens per sequence -- TODO confirm.
            proposals: The proposals to score.

        Returns:
            Annotated as a pair of tensors; see the review note on this
            method regarding ``SpeculativeScores``.

        Raises:
            NotImplementedError: Always, in this base declaration.
        """
        raise NotImplementedError
|