123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458 |
- from array import array
- from itertools import chain, count
- from typing import Iterator, List, Optional, Tuple
- import torch
- from aphrodite import SamplingParams
- from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
- ExecuteModelRequest, SequenceData,
- SequenceGroupMetadata, get_all_seq_ids)
- from aphrodite.modeling.layers.sampler import SamplerOutput
- from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
- SpeculativeScorer,
- SpeculativeScores)
- from aphrodite.spec_decode.util import nvtx_range, split_batch_by_proposal_len
- from aphrodite.worker.worker_base import WorkerBase
# Readability aliases for the integer identifiers used in this module. They are
# plain ``int`` at runtime; the distinct names only document intent.
TokenId = int  # vocabulary token id
SeqId = int  # sequence id in the incoming (unexpanded) batch
TargetSeqId = int  # sequence id allocated for a row of the expanded batch

# Sampling parameters used for the non-bonus positions of an expanded sequence
# when the request samples with a nonzero temperature.
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    def __init__(self, scorer_worker: WorkerBase, device: str,
                 vocab_size: int):
        """
        Args:
            scorer_worker: Worker whose ``execute_model`` scores the expanded
                batch.
            device: Device identifier, stored as ``self._device``.
            vocab_size: Vocabulary size, used to reshape the per-token
                probability and logprob tensors.
        """
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        Sequences with a proposal len of zero are not expanded; they are scored
        for a single (non-speculative) token only.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: Sampled token ids, probabilities, logprobs and
            (when the scorer returns them) hidden states for every position,
            with leading shape [batch_size, k + 1]. For sequences that were not
            speculated on, only position 0 is populated; the remaining
            positions hold token id -1, probability 0 and logprob -inf.
        """
        # TODO: perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Filter the list to ignore -1 proposals. The filtered rows are paired
        # positionally with the speculative sequences chosen in _expand_batch
        # (proposal len != 0).
        # NOTE(review): assumes rows containing -1 are exactly the rows with a
        # proposal len of zero — that pairing comes from the proposer; confirm.
        proposal_token_ids_list_without_skips = [
            token_ids for token_ids in proposal_token_ids_list
            if -1 not in token_ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            contracted = self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            contracted = self._contract_batch(
                contracted_bs=len(execute_model_req.seq_group_metadata_list),
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
            logprobs=spec_logprobs,
            hidden_states=all_hidden_states,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.

        Args:
            seq_group_metadata_list: The original (unexpanded) batch.
            proposal_token_ids_list: Proposal token ids for the speculative
                sequences only (those with a nonzero proposal len), in batch
                order.
            proposal_lens_list: Proposal len of every sequence in the batch.

        Returns:
            ``(spec_indices, non_spec_indices, target_seq_group_metadata_list,
            num_scoring_tokens)``. The expanded batch lists the scoring
            sequences of every speculative sequence first, followed by the
            unmodified non-speculative sequences; ``num_scoring_tokens`` is the
            length of that first part.
        """
        # Aphrodite currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
            split_batch_by_proposal_len(
                seq_group_metadata_list, proposal_lens_list)

        target_seq_group_metadata_list = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        num_scoring_tokens = len(target_seq_group_metadata_list)
        target_seq_group_metadata_list.extend(non_spec_seqs)

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_batch(
            self, contracted_bs: int, target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int], k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Contract the expanded batch back into its original size.

        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.

        Args:
            contracted_bs: Number of sequence groups in the original batch.
            target_sampler_output: Scorer output for the expanded batch; the
                first ``num_scoring_tokens`` samples belong to the speculative
                sequences, the rest to the non-speculative ones.
            proposals: The proposals that were scored.
            num_scoring_tokens: Number of samples produced for the speculative
                sequences.
            non_spec_indices: Original batch positions of the sequences that
                were not speculated on.
            spec_indices: Original batch positions of the speculative
                sequences.
            k: Requested number of lookahead slots. Ignored: the width of
                ``proposals.proposal_token_ids`` is used instead, because that
                is what the expanded batch was built from.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)`` with leading shape
            [contracted_bs, k + 1]; ``hidden_states`` is None when the scorer
            returned none.
        """
        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Use the proposal width rather than the ``k`` argument (see docstring).
        k = proposals.proposal_token_ids.shape[1]

        # Map the distinct sequences used to score each token, of shape
        # [num_spec_seqs * (k + 1)], back to [num_spec_seqs, k + 1]. One row per
        # speculative sequence is exactly what the ``spec_indices`` assignment
        # below requires.
        num_spec_seqs = len(spec_indices)
        target_token_ids = target_token_ids.reshape(num_spec_seqs, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)
        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Positions not written below keep these fill values.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        all_hidden_states: Optional[torch.Tensor] = None
        if target_hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))

        if non_spec_indices:
            # Each non-speculative sequence contributes one sample, stored at
            # position 0 of its row.
            all_tokens[non_spec_indices, :1] = \
                non_spec_target_token_ids.unsqueeze(1)
            all_probs[non_spec_indices, :1, :] = \
                non_spec_target_probs.unsqueeze(1)
            all_logprobs[non_spec_indices, :1, :] = \
                non_spec_target_logprobs.unsqueeze(1)
            if all_hidden_states is not None:
                assert non_spec_target_hidden_states is not None
                all_hidden_states[non_spec_indices, :1, :] = \
                    non_spec_target_hidden_states.unsqueeze(1)

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        return all_tokens, all_probs, all_logprobs, all_hidden_states

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Contract the expanded batch back into its original size.

        This maps the scores of speculative tokens back to their original
        sequences.

        It assumes all sequences in the batch were previously expanded.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)`` with leading shape
            [batch_size, k + 1]; ``hidden_states`` is None when the scorer
            returned none.
        """
        # Map distinct sequences used to score each token
        # of shape [batch_size * (k + 1)] back to [batch_size, k + 1].
        contracted_bs, k = proposals.proposal_token_ids.shape

        # Reshape tensors to original batch size
        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k + 1)
        target_probs = target_sampler_output.sampled_token_probs.reshape(
            *target_token_ids.shape, self._vocab_size)
        target_logprobs = target_sampler_output.logprobs.reshape(
            target_probs.shape)
        target_hidden_states = target_sampler_output.hidden_states
        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        return (target_token_ids, target_probs, target_logprobs,
                target_hidden_states)

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.

        target_seq_ids_iter provides sequence ids for the expanded batch,
        fulfilling the requirement that no seq id in the expanded batch is equal
        to the seq id in the original batch.

        Returns:
            The expanded sequences, k+1 per input sequence, grouped by input
            sequence in input order.
        """
        if not seq_group_metadata_list:
            return []

        target_seq_group_metadata = list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

        return target_seq_group_metadata

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.
        """
        assert not input_seq_group_metadata.is_prompt, (
            "Speculating on "
            "prompts not yet supported")
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        # Use simpler sampling parameters apart from for final token
        # (in particular don't do seeded sampling) since those sampled tokens
        # aren't used.
        # We don't replace the sampling_params in the greedy case because
        # this also controls whether the probs get modified in the sampler
        # (see use of _modify_greedy_probs_inplace there).
        sampling_params = input_seq_group_metadata.sampling_params
        non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
            if sampling_params.temperature else sampling_params

        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        last_index = len(token_ids_to_score) - 1
        for i, token_ids in enumerate(token_ids_to_score):
            # Only the last (bonus) position keeps the request's own params.
            target_sampling_params = sampling_params if i == last_index \
                else non_bonus_sampling_params
            target_seq_group_metadata_list.append(
                self._create_single_target_seq_group_metadata(
                    input_seq_group_metadata,
                    input_seq_id,
                    next(target_seq_ids_iter),
                    token_ids,
                    sampling_params=target_sampling_params,
                ))

        return target_seq_group_metadata_list

    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
            sampling_params: Sampling parameters for the new sequence group.

        Returns:
            A decode-step SequenceGroupMetadata for ``target_seq_id`` that
            reuses the input sequence's block table and has the input's output
            tokens followed by ``token_ids``.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.prompt_token_ids_array
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]

        new_seq_data_dict = {
            target_seq_id:
            SequenceData(
                prompt_token_ids,
                _output_token_ids=array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                        new_output_token_ids),
            ),
        }
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluates one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        for data in new_seq_data_dict.values():
            data.update_num_computed_tokens(data.get_len() - 1)

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )

    @staticmethod
    def _split_scoring_output(
        sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
        """Split the target model output into speculative and non-speculative
        output.

        Returns:
            ``(spec_token_ids, spec_probs, spec_logprobs, spec_hidden_states,
            non_spec_token_ids, non_spec_probs, non_spec_logprobs,
            non_spec_hidden_states)``. Token ids are flattened before
            splitting; both hidden-state entries are None when the sampler
            output carries no hidden states.
        """
        # Aphrodite currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        #
        # First samples are from speculative scoring, latter samples are non-
        # speculative samples.
        split_sizes = (num_scoring_tokens,
                       sampler_output.sampled_token_ids.numel() -
                       num_scoring_tokens)
        (spec_probs, non_spec_probs
         ) = sampler_output.sampled_token_probs.split(split_sizes)
        (spec_sampled_tokens, non_spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
        (
            spec_logprobs,
            non_spec_logprobs,
        ) = sampler_output.logprobs.split(split_sizes)

        if sampler_output.hidden_states is not None:
            (
                spec_hidden_states,
                non_spec_hidden_states,
            ) = sampler_output.hidden_states.split(split_sizes)
        else:
            spec_hidden_states, non_spec_hidden_states = None, None

        return (spec_sampled_tokens, spec_probs, spec_logprobs,
                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
                non_spec_logprobs, non_spec_hidden_states)

    @staticmethod
    def _create_target_seq_id_iterator(
            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.

        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids, or at 0 when ``seq_ids`` is empty (the
        bare ``max`` would raise ValueError there).
        """
        return count(start=max(seq_ids, default=-1) + 1)

    @staticmethod
    def _get_token_ids_to_score(
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given a list of proposal token ids, return a list of token ids that
        should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        empty_token_ids: List[TokenId] = []

        token_ids_to_score = [empty_token_ids]
        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                                  for i in range(len(full_spec_token_ids)))
        return token_ids_to_score
|