- from functools import cached_property
- from typing import List, Optional, Tuple
- import torch
- from loguru import logger
- from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
- SequenceGroupMetadata)
- from aphrodite.modeling.layers.rejection import RejectionSampler
- from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
- from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
- SpeculativeScorer,
- SpeculativeScores)
- from aphrodite.spec_decode.metrics import AsyncMetricsCollector
- from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
- from aphrodite.spec_decode.ngram_worker import NGramWorker
- from aphrodite.spec_decode.util import (create_sequence_group_output,
- get_all_num_logprobs, get_all_seq_ids,
- get_sampled_token_logprobs, nvtx_range,
- split_batch_by_proposal_len)
- from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
- WorkerBase)
- class SpecDecodeWorker(LoraNotSupportedWorkerBase):
- """Worker which implements speculative decoding.
- Speculative decoding reduces decoding per-token latency by using a proposal
- method, such as a small draft model, to speculate ahead of a larger LLM. The
- probabilities of the speculative tokens are then determined by the larger
- LLM, after which some verification routine determines which (if any) of the
- speculative tokens are accepted by the larger LLM.
- The current implementation has the following limitations:
- * Only draft-model proposal is implemented (contributions for more forms are
- welcome!).
- * Only top-1 proposal and scoring are implemented. Tree-attention is left as
- future work.
- * Only lossless rejection sampling is supported. Contributions adding lossy
- verification routines are welcome (e.g. Medusa's typical acceptance).
- * All sequences in a batch must have the same proposal length, or zero. This
- can be improved by having per-sequence speculation in the future.
- * The scoring forward pass is done without an MQA kernel, which is
- suboptimal especially as the batch size, proposal length, and sequence
- lengths grow. Contributions to add MQA scoring are welcome once
- correctness tests pass.
- """
- @classmethod
- def create_worker(
- cls,
- scorer_worker: WorkerBase,
- draft_worker_kwargs,
- ) -> "SpecDecodeWorker":
- ngram_prompt_lookup_max = (
- draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
- ngram_prompt_lookup_min = (
- draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
- if ngram_prompt_lookup_max > 0:
- proposer_worker = NGramWorker(**draft_worker_kwargs)
- proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
- ngram_prompt_lookup_max)
- else:
- proposer_worker = MultiStepWorker(**draft_worker_kwargs)
- logger.info("Configuring SpecDecodeWorker with "
- f"proposer={type(proposer_worker)}")
- return SpecDecodeWorker(
- proposer_worker,
- scorer_worker,
- # TODO: disable strict mode for speedup.
- rejection_sampler=RejectionSampler(strict_mode=True),
- )
- def __init__(
- self,
- proposer_worker: WorkerBase,
- scorer_worker: WorkerBase,
- rejection_sampler: RejectionSampler,
- metrics_collector: Optional[AsyncMetricsCollector] = None,
- ):
- """
- Create a SpecDecodeWorker.
- Args:
- proposer_worker: A worker that can produce speculative tokens for
- sequences.
- scorer_worker: A worker that produces probabilities of speculative
- tokens according to some base model. Typically a vanilla
- Aphrodite Worker.
- rejection_sampler: A Torch module used to perform modified rejection
- sampling for speculative decoding.
- metrics_collector: Helper class for collecting metrics; can be set
- for testing purposes.
- """
- self.proposer_worker = proposer_worker
- self.scorer_worker = scorer_worker
- self.rejection_sampler = rejection_sampler
- self._metrics = AsyncMetricsCollector(
- rejection_sampler
- ) if metrics_collector is None else metrics_collector
- self.probs_dtype = self.rejection_sampler.probs_dtype
- self.token_id_dtype = self.rejection_sampler.token_id_dtype
- # Lazy initialization.
- self.scorer: SpeculativeScorer
- def init_device(self) -> None:
- """Initialize both scorer and proposer models.
- """
- # The scorer worker model is initialized first in case the proposer
- # model has a smaller TP degree than the target worker.
- self.scorer_worker.init_device()
- self.proposer_worker.init_device()
- # NOTE: load_model is not part of the WorkerBase interface.
- self.scorer_worker.load_model()
- self.proposer_worker.load_model()
- self._metrics.init_gpu_tensors(self.rank)
- self.rejection_sampler.init_gpu_tensors(self.rank)
- self.scorer = BatchExpansionTop1Scorer(
- scorer_worker=self.scorer_worker,
- device=self.device,
- vocab_size=self._vocab_size)
- self._configure_model_sampler_for_spec_decode()
- def _configure_model_sampler_for_spec_decode(self):
- """Configure model sampler to emit GPU tensors. This allows spec decode
- to keep data on device without transferring to CPU and serializing,
- which significantly reduces overhead of rejection sampling.
- NOTE: This breaks abstraction boundaries pretty badly. The better
- design is to have the "move to CPU and serialize" sampling decision be
- done outside of the model/sampler; this way the "last-mile" worker
- object which interfaces with the scheduler can serialize and incur the
- performance hit as necessary. This allows us to run the worker several
- iterations in a row without incurring the "move to CPU and serialize"
- performance penalty.
- Since this requires a large change to Aphrodite, we defer it to later and
- temporarily accept this broken abstraction boundary.
- NOTE: This will require a special check if the proposer worker
- does not have a sampler (e.g. ngram speculation).
- """
- (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
- ) = True
- self.proposer_worker.set_include_gpu_probs_tensor()
- def determine_num_available_blocks(self) -> Tuple[int, int]:
- """Determine the number of cache blocks to use.
- This is done by profiling the scorer model (which is typically the
- larger of the two). The total memory that the scorer cache alone would
- have used is then split between the proposer and scorer KV caches such
- that both caches hold the same number of blocks.
- """
- num_gpu_blocks, num_cpu_blocks = (
- self.scorer_worker.determine_num_available_blocks())
- scorer_cache_block_size_bytes = (
- self.scorer_worker.get_cache_block_size_bytes())
- proposer_cache_block_size_bytes = (
- self.proposer_worker.get_cache_block_size_bytes())
- new_num_gpu_blocks = split_num_cache_blocks_evenly(
- scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
- num_gpu_blocks)
- return new_num_gpu_blocks, num_cpu_blocks
- def initialize_cache(self, num_gpu_blocks: int,
- num_cpu_blocks: int) -> None:
- """Initialize the cache engine of the scorer and proposer workers.
- """
- self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
- num_cpu_blocks=num_cpu_blocks)
- self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
- num_cpu_blocks=num_cpu_blocks)
- @torch.inference_mode()
- def execute_model(
- self,
- execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
- """Perform speculative decoding on the input batch.
- """
- assert execute_model_req.seq_group_metadata_list is not None, (
- "speculative decoding "
- "requires non-None seq_group_metadata_list")
- # If no spec tokens, call the proposer and scorer workers normally.
- # Used for prefill.
- if execute_model_req.num_lookahead_slots == 0 or len(
- execute_model_req.seq_group_metadata_list) == 0:
- return self._run_no_spec(execute_model_req)
- return self._run_speculative_decoding_step(execute_model_req)
- @nvtx_range("spec_decode_worker._run_no_spec")
- def _run_no_spec(
- self,
- execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
- """Run a prefill step, without any speculation. The input is sent to the
- proposer and scorer model so that the KV cache is consistent between the
- two.
- """
- #logger.info("run proposer worker no spec")
- self.proposer_worker.execute_model(execute_model_req)
- #logger.info("run target worker no spec")
- sampler_output = self.scorer_worker.execute_model(execute_model_req)
- assert len(sampler_output) == 1
- sampler_output = sampler_output[0]
- # Clear device tensors from sampler output. This reduces communication
- # overhead when the engine runs in a different process than the workers.
- sampler_output.probs = None
- sampler_output.sampled_tokens = None
- sampler_output.logprobs = None
- return [sampler_output]
- @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
- def _run_speculative_decoding_step(
- self,
- execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
- """Execute a single step of speculative decoding.
- This invokes the proposer worker to get k speculative tokens for each
- sequence, then scores each speculative token using the scoring worker.
- Returns a list of SamplerOutput, each containing a single token per
- sequence.
- """
- #logger.info("get spec proposals")
- # Generate proposals using draft worker.
- proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
- #logger.info("score proposals")
- proposal_scores = self.scorer.score_proposals(
- execute_model_req,
- proposals,
- )
- #logger.info("verify proposals")
- accepted_token_ids, target_logprobs = self._verify_tokens(
- execute_model_req.seq_group_metadata_list, proposal_scores,
- proposals, execute_model_req.num_lookahead_slots)
- #logger.info("create output list")
- return self._create_output_sampler_list(
- execute_model_req.seq_group_metadata_list,
- accepted_token_ids,
- target_logprobs=target_logprobs,
- k=execute_model_req.num_lookahead_slots)
- @nvtx_range("spec_decode_worker._verify_tokens")
- def _verify_tokens(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- proposal_scores: SpeculativeScores,
- proposals: SpeculativeProposals,
- max_proposal_len: int,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Determine which speculative tokens are accepted using the
- probabilities of each token according to the proposer and scorer models.
- Returns a tuple of Tensors, one for the accepted token ids and one for
- the logprobs according to the scoring model.
- """
- proposal_lens_list = proposals.proposal_lens.tolist()
- # Aphrodite currently only supports proposal lens equal to zero or the
- # batch proposal len. This adds some complexity (splitting the batch
- # into spec and non spec sequences) and should be removed in the
- # future. It can be done by supporting per-sequence proposal lens.
- _, spec_indices = split_batch_by_proposal_len(
- seq_group_metadata_list,
- proposal_lens_list,
- select_proposal_len_zero=False)
- _, non_spec_indices = split_batch_by_proposal_len(
- seq_group_metadata_list,
- proposal_lens_list,
- select_proposal_len_zero=True)
- original_indices = spec_indices + non_spec_indices
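- # The tensors below are built in [spec..., non_spec...] order; this
- # combined index list is used at the end of the method to scatter the
- # results back into the original batch order.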
- # Get probabilities of target model, excluding bonus token.
- proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
- # Get non-speculative sampled tokens from target model.
- non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
- # Get bonus tokens from target model.
- bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
- # Get probabilities according to proposal method.
- proposal_probs = proposals.proposal_probs[spec_indices]
- # Get proposed tokens.
- proposal_token_ids = proposals.proposal_token_ids[spec_indices]
- accepted_token_ids = self.rejection_sampler(
- target_probs=proposal_verifier_probs,
- bonus_token_ids=bonus_token_ids,
- draft_probs=proposal_probs,
- draft_token_ids=proposal_token_ids,
- )
- # Append output tokens from non-speculative sequences to
- # the accepted token ids tensor.
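- # Non-speculative sequences contribute only one real token this step, so
- # every position after the first is padded with the -1 "no token" marker.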
- non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
- 1).clone()
- non_spec_token_ids[:, 1:] = -1
- accepted_token_ids = torch.cat(
- [accepted_token_ids, non_spec_token_ids])
- logprobs = proposal_scores.logprobs
- # Rearrange so that results are in the order of the original seq group
- # metadata.
- accepted_token_ids[original_indices] = accepted_token_ids.clone()
- return accepted_token_ids, logprobs
- def _create_output_sampler_list(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
- target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
- k: int,
- ) -> List[SamplerOutput]:
- """Given the accepted token ids, create a list of SamplerOutput.
- The output is padded with -1 tokens such that each sequence has
- the same number of outputs.
- """
- batch_size, num_steps = accepted_token_ids.shape
- # Organize input tensors by step instead of by sequence.
- target_logprobs_by_step = target_logprobs.transpose(0, 1)
- accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
- # Get the logprobs/rank of the accepted tokens.
- (accepted_token_id_ranks_by_step,
- accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
- logprob_tensor=target_logprobs_by_step,
- sampled_token_ids=accepted_token_ids_by_step,
- )
- # Get the top-k logprobs (which may or may not include the logprob of
- # the accepted token).
- (topk_logprobs_by_step,
- topk_indices_by_step) = target_logprobs_by_step.topk(
- k=self.scorer_worker.model_config.max_logprobs,
- dim=-1,
- )
- # Get the sequence ids and num_logprobs (sampling parameter) in the
- # batch.
- seq_ids = get_all_seq_ids(seq_group_metadata_list)
- num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
- # Serialize all tensors to CPU Python lists.
- accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
- accepted_token_id_ranks_by_step = (
- accepted_token_id_ranks_by_step.tolist())
- accepted_token_id_logprobs_by_step = (
- accepted_token_id_logprobs_by_step.tolist())
- topk_logprobs_by_step = topk_logprobs_by_step.tolist()
- topk_indices_by_step = topk_indices_by_step.tolist()
- # Construct the output on a per-step, per-sequence basis.
- sampler_output_list = []
- for step_index in range(num_steps):
- if all(token_id == -1
- for token_id in accepted_token_ids_by_step[step_index]):
- break
- step_output_token_ids = []
- for sequence_index in range(batch_size):
- # Each sequence may have a different num_logprobs; retrieve it.
- num_logprobs = num_logprobs_per_seq[sequence_index]
- step_output_token_ids.append(
- create_sequence_group_output(
- token_id=accepted_token_ids_by_step[step_index]
- [sequence_index],
- token_id_logprob_rank=accepted_token_id_ranks_by_step[
- step_index][sequence_index],
- token_id_logprob=accepted_token_id_logprobs_by_step[
- step_index][sequence_index],
- seq_id=seq_ids[sequence_index],
- topk_token_ids=topk_indices_by_step[step_index]
- [sequence_index][:num_logprobs],
- topk_logprobs=topk_logprobs_by_step[step_index]
- [sequence_index][:num_logprobs],
- ))
- sampler_output_list.append(
- SamplerOutput(outputs=step_output_token_ids))
- maybe_rejsample_metrics = (
- self._metrics.maybe_collect_rejsample_metrics(k))
- if maybe_rejsample_metrics is not None:
- sampler_output_list[
- 0].spec_decode_worker_metrics = maybe_rejsample_metrics
- return sampler_output_list
- @cached_property
- def _vocab_size(self) -> int:
- """Get the vocab size of the model and make sure it's consistent between
- draft and target workers.
- """
- vocab_sizes = [
- worker.vocab_size
- for worker in [self.proposer_worker, self.scorer_worker]
- ]
- assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
- return vocab_sizes[0]
- @property
- def rank(self):
- return self.scorer_worker.rank
- @property
- def device(self):
- return self.scorer_worker.device
- def get_cache_block_size_bytes(self):
- """Return the size of a cache block in bytes.
-
- This function is only used to compose workers within a SpecDecodeWorker.
- We leave composing a SpecDecodeWorker within a SpecDecodeWorker
- undefined for now, although it could be implemented in the future.
- See https://arxiv.org/abs/2308.04623.
- """
- raise NotImplementedError
- def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
- proposer_cache_block_size_bytes: int,
- total_num_gpu_blocks: int) -> int:
- """Given total_num_gpu_blocks, the number of GPU blocks that could be
- allocate to the target model, this function calculates how many blocks
- should be given to the draft and target model.
- Note that usually the block size, in bytes, of each model is different,
- as it's a function of number of KV/layer, number of heads, and hidden
- dimension size.
- Since the target and draft models allocate the same number of blocks, we
- simply calculate the largest block count such that, when allocated by both
- models, the total KV cache memory usage is no larger than the memory the
- target model could have used for its KV cache alone.
- """
- new_num_gpu_blocks = int(
- total_num_gpu_blocks * scorer_cache_block_size_bytes /
- (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
- return new_num_gpu_blocks
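- # Worked example with made-up sizes: if a scorer block takes 160 KiB, a
- # proposer block takes 40 KiB, and the scorer alone could fit 1000 blocks,
- # then new_num_gpu_blocks = int(1000 * 160 / (40 + 160)) = 800. Giving 800
- # blocks to each cache uses 800 * (160 + 40) KiB = 1000 * 160 KiB, i.e. the
- # same memory the scorer model would have used on its own.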