123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273 |
- from typing import List, Optional, Set, Tuple
- import torch
- from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
- SequenceGroupMetadata)
- from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
- SpeculativeProposer)
- from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
- from aphrodite.spec_decode.util import sampler_output_to_torch
class Top1Proposer(SpeculativeProposer):
    """Proposer that keeps over-long sequences away from the draft model.

    Sequences that would exceed the maximum model length if speculated upon
    are separated out. This makes it possible to pair models such as a
    JackFram/llama-68m draft with meta-llama/Llama2-13b-chat-hf, where
    llama-68m has max_position_embeddings of 2048 and Llama2-13b has
    max_position_embeddings of 4096.

    Sequences that exceed the draft model's proposal length are treated as
    "non-spec sequences": they skip the draft model and are decoded normally
    by the target model.

    Only proposal lengths of 0 and k are supported at the moment, where k is
    a single proposal length shared by the whole batch. Per-sequence proposal
    lengths are future work for Aphrodite.
    """

    def __init__(
        self,
        worker: ProposerWorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        """Store the draft worker and the parameters used to build proposals.

        Args:
            worker: Draft worker whose ``sampler_output`` produces proposals.
            device: Device on which the proposal tensors are created.
            vocab_size: Size of the last dimension of the empty proposal
                probability tensor.
            max_proposal_len: Optional length limit; ``None`` disables the
                length check when splitting the batch.
        """
        self._worker = worker
        self._device = device
        self._vocab_size = vocab_size
        self.max_proposal_len = max_proposal_len

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculative proposals for every sequence in the batch.

        Sequences which would exceed the max model length are not sent to
        the draft worker.

        Args:
            execute_model_req: The batch to speculate on; its
                ``num_lookahead_slots`` is used as the proposal length.
            seq_ids_with_bonus_token_in_last_step: Forwarded unchanged to the
                draft worker's ``sampler_output``.

        Returns:
            A ``SpeculativeProposals`` covering the whole batch, with
            ``no_proposals`` set when no sampler output is available.
        """
        k = execute_model_req.num_lookahead_slots
        all_seq_groups = execute_model_req.seq_group_metadata_list

        # Partition the batch into speculative and non-speculative sequences.
        proposal_lens, spec_seq_groups, spec_indices = (
            self._split_by_proposal_len(all_seq_groups, k))

        # Defaults for the case where nothing in the batch can be speculated.
        draft_output = None
        is_transposed = False

        if spec_seq_groups:
            # Run the draft worker on the speculative sequences only.
            # When the worker reports a transposed output, token_ids is
            # laid out as [batch] in a proposal_len-sized list; otherwise
            # it is [proposal_len] in a batch-sized list.
            prev_hidden = execute_model_req.previous_hidden_states
            if prev_hidden is not None:
                prev_hidden.prune(spec_seq_groups)

            draft_req = ExecuteModelRequest(
                seq_group_metadata_list=spec_seq_groups,
                num_lookahead_slots=k,
                previous_hidden_states=prev_hidden,
            )
            draft_output, is_transposed = self._worker.sampler_output(
                execute_model_req=draft_req,
                sample_len=k,
                seq_ids_with_bonus_token_in_last_step=(
                    seq_ids_with_bonus_token_in_last_step),
            )
            proposal_lens, draft_output, spec_indices = (
                self._remove_no_proposal_seqs(proposal_lens, draft_output,
                                              spec_indices, is_transposed))

        # Bring speculative and non-speculative sequences back into a single
        # batch-wide representation.
        tokens, probs, lens_tensor = self._merge_outputs(
            batch_size=len(all_seq_groups),
            proposal_len=k,
            maybe_sampler_output=draft_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=spec_indices,
            sampler_transposed=is_transposed,
        )

        return SpeculativeProposals(
            proposal_token_ids=tokens,
            proposal_probs=probs,
            proposal_lens=lens_tensor,
            no_proposals=draft_output is None,
        )

    def _split_by_proposal_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Partition the batch by whether a sequence may be speculated on.

        The two groups are:
          1. Sequences with a non-zero proposal length.
          2. Sequences with a zero proposal length, either because
             speculation is disabled for them or because they would exceed
             the maximum model length.

        Side effect: for every sequence group whose speculation was not
        already disabled, ``num_speculative_tokens`` is overwritten with the
        proposal length that was granted (0 or ``proposal_len``).

        Returns:
            A tuple ``(proposal_lens, nonzero_seqs, nonzero_indices)`` where
            ``proposal_lens`` has one entry per input sequence group and the
            other two lists describe the groups that will be speculated on.
        """
        lens: List[int] = []
        spec_groups: List[SequenceGroupMetadata] = []
        spec_indices: List[int] = []
        limit = self.max_proposal_len

        for idx, group in enumerate(seq_group_metadata_list):
            if group.num_speculative_tokens == 0:
                # Speculative decoding has been turned off for this request
                # (e.g. due to high traffic).
                lens.append(0)
                continue

            # Only the first sequence of the group is consulted.
            first_seq = next(iter(group.seq_data.values()))
            cur_len = first_seq.get_len()

            # Only a proposal length of 0 or the global batch proposal length
            # is supported. When a limit is configured, a sequence is only
            # speculated on if it stays strictly below that quota.
            fits = limit is None or cur_len + proposal_len < limit
            if fits:
                spec_groups.append(group)
                spec_indices.append(idx)

            granted = proposal_len if fits else 0
            lens.append(granted)
            group.num_speculative_tokens = granted

        return lens, spec_groups, spec_indices

    @staticmethod
    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                 nonzero_proposal_len_indices, transposed):
        """Drop sequences for which the draft worker returned no proposal.

        Such sequences (their entry in ``maybe_sampler_output`` is ``None``)
        are removed from ``nonzero_proposal_len_indices`` and their proposal
        length is reset to 0, which can avoid scoring overheads.

        Returns:
            A tuple ``(proposal_lens, sampler_output, nonzero_indices)``.
            The inputs are returned untouched when ``maybe_sampler_output``
            is ``None`` or ``transposed`` is true.
        """
        # A None output means the draft worker proposed nothing for any
        # sequence, so there is nothing to filter. Transposed outputs are
        # not supported here for now: it is not straightforward for draft
        # workers emitting transposed sampler outputs to express
        # "no proposal" for an individual sequence.
        if maybe_sampler_output is None or transposed:
            return (proposal_lens, maybe_sampler_output,
                    nonzero_proposal_len_indices)

        kept_lens: List[int] = []
        kept_indices: List[int] = []
        kept_outputs: List[SamplerOutput] = []

        total = len(proposal_lens)
        num_spec = len(nonzero_proposal_len_indices)
        cursor = 0  # position within nonzero_proposal_len_indices
        pos = 0  # position within proposal_lens

        while pos < total and cursor < num_spec:
            if pos >= nonzero_proposal_len_indices[cursor]:
                # This sequence was sent to the draft worker.
                draft = maybe_sampler_output[cursor]
                if draft is None:
                    # The worker produced no proposal for it.
                    kept_lens.append(0)
                else:
                    # The worker produced a proposal: keep its length, its
                    # batch index and its sampler output.
                    kept_lens.append(proposal_lens[pos])
                    kept_indices.append(pos)
                    kept_outputs.append(draft)
                cursor += 1
            else:
                # This sequence was never sent to the draft worker, so its
                # proposal length was already 0 beforehand.
                assert proposal_lens[pos] == 0
                kept_lens.append(0)
            pos += 1

        # Whatever is left after the last speculated sequence should have a
        # proposal length of 0; copy those entries over as they are.
        kept_lens.extend(proposal_lens[pos:])

        # The sampler output is assumed never to be a list made up entirely
        # of Nones; in that situation this function should not be called.
        assert kept_outputs

        return (kept_lens, kept_outputs, kept_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge draft proposals with the sequences that were skipped.

        Returns:
            A tuple ``(proposal_tokens, proposal_probs, proposal_lens)`` of
            tensors whose first dimension covers the whole batch. Rows of
            skipped sequences hold -1 tokens, zero probabilities and a
            proposal length of 0.
        """
        device = self._device

        if maybe_sampler_output is None:
            # No speculative tokens at all: return empty proposals built by
            # expanding scalar tensors to the expected shapes.
            empty_tokens = torch.tensor(
                -1, dtype=torch.long,
                device=device).expand(batch_size, proposal_len)
            empty_probs = torch.tensor(
                0, dtype=torch.float32,
                device=device).expand(batch_size, proposal_len,
                                      self._vocab_size)
            empty_lens = torch.tensor(
                0, dtype=torch.long,
                device=device).expand(len(proposal_lens))
            return empty_tokens, empty_probs, empty_lens

        draft_tokens, draft_probs, _ = sampler_output_to_torch(
            maybe_sampler_output, sampler_transposed)

        # Re-lay out the draft tensors so that every sequence in the batch
        # has a row. A row may be an empty proposal, e.g. [-1, -1, -1].
        merged_tokens = draft_tokens.new_full(
            size=(batch_size, *draft_tokens.shape[1:]),
            fill_value=-1,
        )
        merged_tokens[nonzero_proposal_len_indices] = draft_tokens

        merged_probs = draft_probs.new_zeros(
            batch_size,
            *draft_probs.shape[1:],
        )
        merged_probs[nonzero_proposal_len_indices] = draft_probs

        lens_tensor = torch.zeros(batch_size,
                                  dtype=torch.long,
                                  device=device)
        lens_tensor[nonzero_proposal_len_indices] = proposal_len

        return merged_tokens, merged_probs, lens_tensor
|