spec_decode_worker.py 45 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983
  1. from collections import defaultdict
  2. from functools import cached_property
  3. from typing import Any, Dict, List, Optional, Set, Tuple
  4. import torch
  5. from loguru import logger
  6. from aphrodite.common.config import ParallelConfig, SpeculativeConfig
  7. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  8. ExecuteModelRequest, HiddenStates,
  9. SequenceGroupMetadata, get_all_seq_ids,
  10. get_all_seq_ids_and_request_ids)
  11. from aphrodite.distributed.communication_op import broadcast_tensor_dict
  12. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  13. from aphrodite.modeling.layers.sampler import SamplerOutput
  14. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  15. SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
  16. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  17. TypicalAcceptanceSampler)
  18. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  19. from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
  20. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  21. SpeculativeScorer,
  22. SpeculativeScores)
  23. from aphrodite.spec_decode.medusa_worker import MedusaWorker
  24. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  25. from aphrodite.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
  26. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  27. from aphrodite.spec_decode.ngram_worker import NGramWorker
  28. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  29. from aphrodite.spec_decode.smaller_tp_proposer_worker import (
  30. SmallerTpProposerWorker)
  31. from aphrodite.spec_decode.target_model_runner import TargetModelRunner
  32. from aphrodite.spec_decode.util import (Timer, create_sequence_group_output,
  33. get_all_num_logprobs,
  34. get_sampled_token_logprobs, nvtx_range,
  35. split_batch_by_proposal_len)
  36. from aphrodite.worker.worker import Worker
  37. from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
  38. def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
  39. """Helper method that is the entrypoint for Executors which use
  40. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
  41. """
  42. assert "speculative_config" in kwargs
  43. speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
  44. assert speculative_config is not None
  45. draft_worker_kwargs = kwargs.copy()
  46. kwargs["model_runner_cls"] = TargetModelRunner
  47. target_worker = Worker(*args, **kwargs)
  48. # Set the disable_logprobs variable in the TargetModelRunner instance
  49. # as per its value specified in the SpeculativeConfig.
  50. target_worker.model_runner.disable_logprobs =\
  51. speculative_config.disable_logprobs
  52. # Override draft-model specific worker args.
  53. draft_worker_kwargs.update(
  54. model_config=speculative_config.draft_model_config,
  55. parallel_config=speculative_config.draft_parallel_config,
  56. ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
  57. ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
  58. # TODO allow draft-model specific load config.
  59. #load_config=load_config,
  60. )
  61. spec_decode_worker = SpecDecodeWorker.create_worker(
  62. scorer_worker=target_worker,
  63. draft_worker_kwargs=draft_worker_kwargs,
  64. disable_by_batch_size=speculative_config.
  65. speculative_disable_by_batch_size,
  66. draft_token_acceptance_method=speculative_config.
  67. draft_token_acceptance_method,
  68. typical_acceptance_sampler_posterior_threshold=speculative_config.
  69. typical_acceptance_sampler_posterior_threshold,
  70. typical_acceptance_sampler_posterior_alpha=speculative_config.
  71. typical_acceptance_sampler_posterior_alpha,
  72. disable_logprobs=speculative_config.disable_logprobs,
  73. disable_log_stats=speculative_config.disable_log_stats,
  74. )
  75. return spec_decode_worker
  76. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  77. """Worker which implements speculative decoding.
  78. Speculative decoding reduces decoding per-token latency by using a proposal
  79. method, such as a small draft model, to speculate ahead of a larger LLM. The
  80. probabilities of the speculative tokens are then determined by the larger
  81. LLM, after which some verification routine determines which (if any) of the
  82. speculative tokens are accepted by the larger LLM.
  83. The current implementation has the following limitations:
  84. * Only draft-model proposal is implemented (contributions for more forms are
  85. welcome!).
  86. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  87. future work.
  88. * All sequences in a batch must have the same proposal length, or zero. This
  89. can be improved by having per-sequence speculation in the future.
  90. * The scoring forward pass is done without an MQA kernel, which is
  91. suboptimal especially as the batch size, proposal length, and sequence
  92. lengths grow. Contributions to add a MQA scoring are welcome once
  93. correctness tests pass.
  94. """
  95. @classmethod
  96. def create_worker(
  97. cls,
  98. scorer_worker: Worker,
  99. draft_worker_kwargs: Dict[str, Any],
  100. disable_by_batch_size: Optional[int],
  101. draft_token_acceptance_method: str,
  102. typical_acceptance_sampler_posterior_threshold: float,
  103. typical_acceptance_sampler_posterior_alpha: float,
  104. disable_logprobs: bool,
  105. disable_log_stats: bool,
  106. ) -> "SpecDecodeWorker":
  107. allow_zero_draft_token_step = True
  108. ngram_prompt_lookup_max = (
  109. draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
  110. ngram_prompt_lookup_min = (
  111. draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
  112. if ngram_prompt_lookup_max > 0:
  113. proposer_worker = NGramWorker(**draft_worker_kwargs)
  114. proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
  115. ngram_prompt_lookup_max)
  116. else:
  117. draft_parallel_config: ParallelConfig = draft_worker_kwargs[
  118. 'parallel_config']
  119. draft_tp = draft_parallel_config.tensor_parallel_size
  120. target_tp = scorer_worker.parallel_config.tensor_parallel_size
  121. if draft_worker_kwargs[
  122. "model_config"].hf_config.model_type == "mlp_speculator":
  123. proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
  124. elif draft_worker_kwargs[
  125. "model_config"].hf_config.model_type == "medusa":
  126. proposer_worker = MedusaWorker(**draft_worker_kwargs)
  127. else:
  128. if draft_tp == 1:
  129. draft_worker_kwargs[
  130. "model_runner_cls"] = TP1DraftModelRunner
  131. else:
  132. if draft_worker_kwargs[
  133. "model_config"].hf_config.model_type == "eagle":
  134. raise NotImplementedError(
  135. "EAGLE does not support TP > 1 yet")
  136. allow_zero_draft_token_step = False
  137. proposer_worker = MultiStepWorker(**draft_worker_kwargs)
  138. proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
  139. proposer_worker, draft_tp, target_tp)
  140. logger.info("Configuring SpecDecodeWorker with "
  141. f"proposer={type(proposer_worker)}")
  142. spec_decode_sampler: SpecDecodeBaseSampler = None
  143. if draft_token_acceptance_method == "rejection_sampler":
  144. spec_decode_sampler = RejectionSampler(
  145. disable_bonus_tokens=False, )
  146. elif draft_token_acceptance_method == "typical_acceptance_sampler":
  147. spec_decode_sampler = TypicalAcceptanceSampler(
  148. disable_bonus_tokens=False,
  149. posterior_threshold=\
  150. typical_acceptance_sampler_posterior_threshold,
  151. posterior_alpha=typical_acceptance_sampler_posterior_alpha,
  152. )
  153. logger.info("Configuring SpecDecodeWorker with "
  154. f"sampler={type(spec_decode_sampler)}")
  155. return SpecDecodeWorker(
  156. proposer_worker,
  157. scorer_worker,
  158. disable_logprobs=disable_logprobs,
  159. disable_log_stats=disable_log_stats,
  160. disable_by_batch_size=disable_by_batch_size,
  161. spec_decode_sampler=spec_decode_sampler,
  162. allow_zero_draft_token_step=allow_zero_draft_token_step)
  163. def __init__(
  164. self,
  165. proposer_worker: ProposerWorkerBase,
  166. scorer_worker: WorkerBase,
  167. spec_decode_sampler: SpecDecodeBaseSampler,
  168. disable_logprobs: bool = False,
  169. disable_log_stats: bool = False,
  170. metrics_collector: Optional[AsyncMetricsCollector] = None,
  171. disable_by_batch_size: Optional[int] = None,
  172. allow_zero_draft_token_step: Optional[bool] = True,
  173. ):
  174. """
  175. Create a SpecDecodeWorker.
  176. Args:
  177. proposer_worker: A worker that can produce speculative tokens for
  178. sequences.
  179. scorer_worker: A worker that produces probabilities of speculative
  180. tokens according to some base model. Typically a vanilla
  181. Aphrodite Worker.
  182. spec_decode_sampler: A Torch module used to perform acceptance
  183. sampling of the draft tokens in the verification step of
  184. speculative decoding. Currently we support two different
  185. types of sampler namely RejectionSampler and
  186. TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
  187. instance of RejectionSampler or TypicalAcceptanceSampler.
  188. disable_logprobs: If set to True, token log probabilities will
  189. not be output in both the draft worker and the target worker.
  190. If set to False, log probabilities will be output by both.
  191. disable_log_stats: If set to True, disable periodic printing of
  192. speculative stage times.
  193. disable_by_batch_size: If the batch size is larger than this,
  194. disable speculative decoding for new incoming requests.
  195. metrics_collector: Helper class for collecting metrics; can be set
  196. for testing purposes.
  197. allow_zero_draft_token_step: whether to allow a step where the draft
  198. model generates no draft token; should disallow when the tp of
  199. draft model is larger than 1
  200. """
  201. self.proposer_worker = proposer_worker
  202. self.scorer_worker = scorer_worker
  203. scorer_runner = getattr(self.scorer_worker, "model_runner", None)
  204. self.generators = scorer_runner.get_generators(
  205. ) if scorer_runner else None
  206. self.disable_by_batch_size = disable_by_batch_size or float("inf")
  207. self.spec_decode_sampler = spec_decode_sampler
  208. self._allow_zero_draft_token_step = allow_zero_draft_token_step
  209. self._metrics = AsyncMetricsCollector(
  210. self.spec_decode_sampler
  211. ) if metrics_collector is None else metrics_collector
  212. # Tracks the sequence IDs that received a bonus token ID in
  213. # their last forward pass. Needed only if KV cache is being
  214. # used for token generation such as in the case of MultiStepWorker.
  215. self._seq_with_bonus_token_in_last_step: Set[int] = set()
  216. # Tracks the currently active request ids and the sequence IDs
  217. # corresponding to them
  218. self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
  219. # Tracks if the proposer worker uses the KV cache or not.
  220. self.probs_dtype = self.spec_decode_sampler.probs_dtype
  221. self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
  222. # Lazy initialization.
  223. self.scorer: SpeculativeScorer
  224. # Hidden states from target model to pass to proposer
  225. # in the subsequent step.
  226. self.previous_hidden_states: Optional[HiddenStates] = None
  227. self._disable_logprobs = disable_logprobs
  228. self._disable_log_stats = disable_log_stats
  229. def init_device(self) -> None:
  230. """Initialize both scorer and proposer models.
  231. """
  232. # The scorer worker model is initialized first in case the proposer
  233. # model has a smaller TP degree than the target worker.
  234. self.scorer_worker.init_device()
  235. self.proposer_worker.init_device()
  236. # NOTE: load_model is not part of the WorkerBase interface.
  237. self.scorer_worker.load_model()
  238. self.proposer_worker.load_model()
  239. self._metrics.init_gpu_tensors(self.rank)
  240. self.spec_decode_sampler.init_gpu_tensors(self.rank)
  241. self.scorer = BatchExpansionTop1Scorer(
  242. scorer_worker=self.scorer_worker,
  243. device=self.device,
  244. vocab_size=self._vocab_size)
  245. self._configure_model_sampler_for_spec_decode()
  246. def load_model(self, *args, **kwargs):
  247. pass
  248. def _configure_model_sampler_for_spec_decode(self):
  249. """Configure model sampler to emit GPU tensors. This allows spec decode
  250. to keep data on device without transferring to CPU and serializing,
  251. which significantly reduces overhead of sampling during verification.
  252. NOTE: This breaks abstraction boundaries pretty badly. The better
  253. design is to have the "move to CPU and serialize" sampling decision be
  254. done outside of the model/sampler; this way the "last-mile" worker
  255. object which interfaces with the scheduler can serialize and incur the
  256. performance hit as necessary. This allows us to run the worker several
  257. iterations in a row without incurring the "move to CPU and serialize"
  258. performance penalty.
  259. Since this requires a large change to Aphrodite, we defer it to later
  260. and temporarily accept this broken abstraction boundary.
  261. NOTE: This will require a special check if the proposer worker
  262. does not have a sampler (e.g. ngram speculation).
  263. """
  264. (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
  265. ) = True
  266. (self.scorer_worker.model_runner.model.sampler.
  267. should_modify_greedy_probs_inplace) = True
  268. self.proposer_worker.set_include_gpu_probs_tensor()
  269. self.proposer_worker.set_should_modify_greedy_probs_inplace()
  270. def determine_num_available_blocks(self) -> Tuple[int, int]:
  271. """Determine the number of cache blocks to use.
  272. This is done by profiling the scorer model (which is typically the
  273. larger of the two). Then the total memory which would be used by the
  274. scorer cache is divided evenly between the proposer and scorer model KV,
  275. such that the number of blocks is equal in both KV caches.
  276. """
  277. num_gpu_blocks, num_cpu_blocks = (
  278. self.scorer_worker.determine_num_available_blocks())
  279. scorer_cache_block_size_bytes = (
  280. self.scorer_worker.get_cache_block_size_bytes())
  281. proposer_cache_block_size_bytes = (
  282. self.proposer_worker.get_cache_block_size_bytes())
  283. new_num_gpu_blocks = split_num_cache_blocks_evenly(
  284. scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
  285. num_gpu_blocks)
  286. return new_num_gpu_blocks, num_cpu_blocks
  287. def initialize_cache(self, num_gpu_blocks: int,
  288. num_cpu_blocks: int) -> None:
  289. """Initialize the cache engine of the scorer and proposer workers.
  290. """
  291. self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  292. num_cpu_blocks=num_cpu_blocks)
  293. self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  294. num_cpu_blocks=num_cpu_blocks)
  295. @torch.inference_mode()
  296. def execute_model(
  297. self,
  298. execute_model_req: Optional[ExecuteModelRequest] = None
  299. ) -> List[SamplerOutput]:
  300. """Perform speculative decoding on the input batch.
  301. """
  302. if self.rank != self._driver_rank:
  303. self._run_non_driver_rank()
  304. return []
  305. if execute_model_req is None:
  306. # This signals that there's no more requests to process for now.
  307. # All workers are running infinite loop with broadcast_tensor_dict,
  308. # and it stops the loop when the driver broadcasts an empty input.
  309. # Send an empty input to notify all other workers to stop their
  310. # execution loop.
  311. broadcast_tensor_dict({}, src=0)
  312. return []
  313. self._track_finished_requests(execute_model_req)
  314. disable_all_speculation = self._should_disable_all_speculation(
  315. execute_model_req)
  316. num_lookahead_slots = execute_model_req.num_lookahead_slots
  317. # Speculative decoding is disabled in the following cases:
  318. # 1. Prefill phase: Speculative decoding is not
  319. # used during the prefill phase.
  320. # 2. Auto-disable enabled: The running queue size exceeds
  321. # the specified threshold.
  322. # 3. No request: There are no requests in the batch, or
  323. # none of the requests in the batch have spec decoding enabled.
  324. # In any of these cases, the proposer and scorer workers
  325. # are called normally.
  326. no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
  327. sgm.num_speculative_tokens == 0
  328. for sgm in execute_model_req.seq_group_metadata_list)
  329. # Broadcast how many lookahead slots are scheduled for this step, and
  330. # whether all speculation is disabled, to all non-driver workers.
  331. # This is required as if the number of draft model runs changes
  332. # dynamically, the non-driver workers won't know unless we perform a
  333. # communication to inform them.
  334. # no_spec is used to signal non-driver worker about prefill vs decode
  335. # stage. This is needed to ensure that order of execution of proposer
  336. # and scorer is same in both driver and non-driver workers (i.e.,
  337. # scorer -> proposer for prefill and proposer -> scorer in decode). This
  338. # order is needed to support models like EAGLE that take scorer states
  339. # as inputs.
  340. broadcast_dict = dict(
  341. num_lookahead_slots=num_lookahead_slots,
  342. no_spec=no_spec,
  343. disable_all_speculation=disable_all_speculation,
  344. )
  345. broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
  346. assert execute_model_req.seq_group_metadata_list is not None, (
  347. "speculative decoding requires non-None seq_group_metadata_list")
  348. self._maybe_disable_speculative_tokens(
  349. disable_all_speculation, execute_model_req.seq_group_metadata_list)
  350. if no_spec:
  351. return self._run_no_spec(execute_model_req,
  352. skip_proposer=disable_all_speculation)
  353. return self._run_speculative_decoding_step(execute_model_req,
  354. num_lookahead_slots)
  355. @torch.inference_mode()
  356. def start_worker_execution_loop(self) -> None:
  357. """Execute model loop to perform speculative decoding
  358. in parallel worker."""
  359. while self._run_non_driver_rank():
  360. pass
  361. def _should_disable_all_speculation(
  362. self, execute_model_req: ExecuteModelRequest) -> bool:
  363. # When the batch size is too large, disable speculative decoding
  364. # to stop trading off throughput for latency.
  365. return (execute_model_req.running_queue_size >=
  366. self.disable_by_batch_size)
  367. def _maybe_disable_speculative_tokens(
  368. self, disable_all_speculation: bool,
  369. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  370. if not disable_all_speculation:
  371. return
  372. for seq_group_metadata in seq_group_metadata_list:
  373. # Once num_speculative_tokens is set to 0, the spec decode
  374. # of this request will be disabled forever.
  375. # TODO: We currently store spec decoding specific
  376. # state in the global data structure, but we should maintain
  377. # this state within spec decode worker.
  378. seq_group_metadata.num_speculative_tokens = 0
  379. def _serialize_sampler_output_no_logprobs(
  380. self, execute_model_req: ExecuteModelRequest,
  381. sampler_output: SamplerOutput) -> SamplerOutput:
  382. """
  383. Creates and returns a `SamplerOutput` with only the sampled token IDs
  384. being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
  385. All other parameters in `CompletionSequenceGroupOutput` related to log
  386. probabilities are skipped.
  387. Args:
  388. execute_model_req (ExecuteModelRequest): The model request that
  389. was executed.
  390. sampler_output (SamplerOutput): The output from the sampler with
  391. only GPU tensors populated.
  392. Returns:
  393. SamplerOutput: A new `SamplerOutput` instance containing a list of
  394. `CompletionSequenceGroupOutput` objects with only sampled token
  395. IDs populated.
  396. """
  397. seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
  398. sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
  399. completion_seq_group_output_list: List[
  400. CompletionSequenceGroupOutput] = []
  401. for index, seq_id in enumerate(seq_ids):
  402. completion_seq_group_output_list.append(
  403. create_sequence_group_output(
  404. token_id=sampled_token_ids_list[index][0],
  405. token_id_logprob_rank=-1,
  406. token_id_logprob=0.0,
  407. seq_id=seq_id,
  408. topk_token_ids=[],
  409. topk_logprobs=[],
  410. ))
  411. return SamplerOutput(outputs=completion_seq_group_output_list)
  412. @nvtx_range("spec_decode_worker._run_no_spec")
  413. def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
  414. skip_proposer: bool) -> List[SamplerOutput]:
  415. """Run a single generation step without any speculation. The input is
  416. sent to the proposer and scorer model so that the KV cache is consistent
  417. between the two. When skip_proposer is True, the proposer model is
  418. not called, meaning that the kv-cache in proposer for requests is not
  419. updated, so they cannot enable spec decode in the rest decoding.
  420. """
  421. sampler_output = self.scorer_worker.execute_model(execute_model_req)
  422. assert len(sampler_output) == 1
  423. sampler_output = sampler_output[0]
  424. # Store hidden states from target model execution.
  425. hidden_states = sampler_output.hidden_states
  426. if hidden_states is not None:
  427. if self.previous_hidden_states is None:
  428. self.previous_hidden_states = HiddenStates(
  429. hidden_states, execute_model_req.seq_group_metadata_list)
  430. else:
  431. self.previous_hidden_states.update(
  432. hidden_states, execute_model_req.seq_group_metadata_list)
  433. if not skip_proposer:
  434. # We prepare the prefill hidden states here so that there no
  435. # additional complexity in worker for spec_decode vs non_spec_decode
  436. # flow and execute_model doesn't need additional modifications.
  437. execute_model_req.previous_hidden_states = \
  438. prepare_prefill_hidden_states(
  439. sampler_output.prefill_hidden_states)
  440. self.proposer_worker.execute_model(execute_model_req)
  441. sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
  442. execute_model_req=execute_model_req, sampler_output=sampler_output)
  443. if self._disable_logprobs else
  444. sampler_output)
  445. # Clear device tensors from sampler output. This reduces communication
  446. # overhead when the engine runs in a different process than the workers.
  447. sampler_output.sampled_token_probs = None
  448. sampler_output.sampled_token_ids = None
  449. sampler_output.logprobs = None
  450. return [sampler_output_to_return]
  451. def _run_non_driver_rank(self) -> bool:
  452. """Run proposer and verifier model in non-driver workers. This is used
  453. for both speculation cases (num_lookahead_slots>0) and non-speculation
  454. cases (e.g. prefill).
  455. Returns True if there are remaining sequences to process.
  456. """
  457. assert self.rank != self._driver_rank
  458. data = broadcast_tensor_dict(src=self._driver_rank)
  459. if not data:
  460. return False
  461. num_lookahead_slots = data["num_lookahead_slots"]
  462. # In case of prefill, scorer_worker has to be run before proposer so
  463. # that the hidden states can be propagated to proposer when needed.
  464. if data["no_spec"]:
  465. self.scorer_worker.execute_model()
  466. if not data["disable_all_speculation"]:
  467. # Even if num_lookahead_slots is zero, we want to run the
  468. # proposer model as it may have KV.
  469. #
  470. # We run the proposer once per lookahead slot. In the future we
  471. # should delegate how many times it runs to the proposer.
  472. for _ in range(max(num_lookahead_slots, 1)):
  473. self.proposer_worker.execute_model()
  474. if not data["no_spec"]:
  475. self.scorer_worker.execute_model()
  476. return True
  477. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  478. def _run_speculative_decoding_step(
  479. self, execute_model_req: ExecuteModelRequest,
  480. num_lookahead_slots: int) -> List[SamplerOutput]:
  481. """Execute a single step of speculative decoding.
  482. This invokes the proposer worker to get k speculative tokens for each
  483. sequence, then scores each speculative token using the scoring worker.
  484. Returns a list of SamplerOutput, each containing a single token per
  485. sequence.
  486. """
  487. assert num_lookahead_slots == execute_model_req.num_lookahead_slots
  488. # Pass last hidden states from target model to proposer
  489. execute_model_req.previous_hidden_states = self.previous_hidden_states
  490. self.previous_hidden_states = None
  491. with Timer() as proposal_timer:
  492. # Generate proposals using draft worker.
  493. proposals = self.proposer_worker.get_spec_proposals(
  494. execute_model_req, self._seq_with_bonus_token_in_last_step)
  495. if not self._allow_zero_draft_token_step and proposals.no_proposals:
  496. #TODO: Fix it #5814
  497. raise RuntimeError("Cannot handle cases where distributed draft "
  498. "workers generate no tokens")
  499. execute_model_req.previous_hidden_states = None
  500. with Timer() as scoring_timer:
  501. proposal_scores = self.scorer.score_proposals(
  502. execute_model_req,
  503. proposals,
  504. )
  505. with Timer() as verification_timer:
  506. accepted_token_ids, target_logprobs = self._verify_tokens(
  507. execute_model_req.seq_group_metadata_list, proposal_scores,
  508. proposals, execute_model_req.num_lookahead_slots)
  509. stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
  510. scoring_timer.elapsed_time_ms,
  511. verification_timer.elapsed_time_ms)
  512. return self._create_output_sampler_list(
  513. execute_model_req.seq_group_metadata_list,
  514. accepted_token_ids,
  515. target_logprobs=target_logprobs,
  516. k=execute_model_req.num_lookahead_slots,
  517. stage_times=stage_times)
  518. @nvtx_range("spec_decode_worker._verify_tokens")
  519. def _verify_tokens(
  520. self,
  521. seq_group_metadata_list: List[SequenceGroupMetadata],
  522. proposal_scores: SpeculativeScores,
  523. proposals: SpeculativeProposals,
  524. max_proposal_len: int,
  525. ) -> Tuple[torch.Tensor, torch.Tensor]:
  526. """Determine which speculative tokens are accepted using the
  527. probabilities of each token according to the proposer and scorer models.
  528. Returns a tuple of Tensors, one for the accepted token ids and one for
  529. the logprobs according to the scoring model.
  530. """
  531. proposal_lens_list = proposals.proposal_lens.tolist()
  532. # Aphrodite currently only supports proposal lens equal to zero or the
  533. # batch proposal len. This adds some complexity (splitting the batch
  534. # into spec and non spec sequences) and should be removed in the
  535. # future. It can be done by supporting per-sequence proposal lens.
  536. (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
  537. seq_group_metadata_list, proposal_lens_list)
  538. original_indices = spec_indices + non_spec_indices
  539. # Get probabilities of target model, excluding bonus token.
  540. proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
  541. # Get non-speculative sampled tokens from target model.
  542. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
  543. # Get bonus tokens from target model.
  544. bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
  545. # Get probabilities according to proposal method.
  546. proposal_probs = proposals.proposal_probs[spec_indices]
  547. # Get proposed tokens.
  548. proposal_token_ids = proposals.proposal_token_ids[spec_indices]
  549. # Sampler arguments
  550. sampler_extra_kwargs: Dict[str, Any] = {}
  551. if self.generators and isinstance(self.spec_decode_sampler,
  552. SpecDecodeStochasticBaseSampler):
  553. sampler_extra_kwargs["seeded_seqs"] = {
  554. idx: self.generators[sgm.request_id]
  555. for idx, sgm in enumerate(seq_group_metadata_list)
  556. if sgm.sampling_params.seed is not None
  557. }
  558. accepted_token_ids = self.spec_decode_sampler(
  559. target_probs=proposal_verifier_probs,
  560. bonus_token_ids=bonus_token_ids,
  561. draft_probs=proposal_probs,
  562. draft_token_ids=proposal_token_ids,
  563. **sampler_extra_kwargs,
  564. )
  565. # Append output tokens from non-speculative sequences to
  566. # the accepted token ids tensor.
  567. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
  568. 1).clone()
  569. non_spec_token_ids[:, 1:] = -1
  570. accepted_token_ids = torch.cat(
  571. [accepted_token_ids, non_spec_token_ids])
  572. logprobs = proposal_scores.logprobs
  573. # Rearrange so that results are in the order of the original seq group
  574. # metadata.
  575. accepted_token_ids[original_indices] = accepted_token_ids.clone()
  576. hidden_states = proposal_scores.hidden_states
  577. if hidden_states is not None:
  578. # Contract hidden states based on accepted tokens
  579. hs_size = hidden_states.shape[-1]
  580. accepted_index = accepted_token_ids + 1 # Convert -1 to 0
  581. accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
  582. index = accepted_index[:, None, None].expand(-1, 1, hs_size)
  583. second_last_token_hidden_states = hidden_states[:, -2] # b x d
  584. hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
  585. # Store hidden states from target model for subsequent decode step
  586. self.previous_hidden_states = HiddenStates(
  587. hidden_states, seq_group_metadata_list,
  588. second_last_token_hidden_states)
  589. return accepted_token_ids, logprobs
  590. def _create_output_sampler_list(
  591. self,
  592. seq_group_metadata_list: List[SequenceGroupMetadata],
  593. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  594. target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
  595. k: int,
  596. stage_times: Tuple[float, float, float],
  597. ) -> List[SamplerOutput]:
  598. """Given the accepted token ids, create a list of SamplerOutput.
  599. The output is padded with -1 tokens such that each sequence has
  600. the same number of outputs.
  601. """
  602. batch_size, num_steps = accepted_token_ids.shape
  603. accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
  604. if self._disable_logprobs:
  605. # We are skipping the logprobs. Hence don't serialize the
  606. # logprobs related tensors from the GPU. Instead create
  607. # empty/dummy lists.
  608. (accepted_token_id_ranks_by_step,
  609. accepted_token_id_logprobs_by_step,
  610. topk_logprobs_by_step, topk_indices_by_step) =\
  611. self._create_dummy_logprob_lists(
  612. batch_size, num_steps,
  613. self.scorer_worker.model_config.max_logprobs)
  614. else:
  615. # Organize input tensors by step instead of by sequence.
  616. target_logprobs_by_step = target_logprobs.transpose(0, 1)
  617. # Serialize all tensors into Python lists.
  618. (accepted_token_id_ranks_by_step,
  619. accepted_token_id_logprobs_by_step,
  620. topk_logprobs_by_step, topk_indices_by_step) =\
  621. self._create_logprob_lists_from_tensors(
  622. target_logprobs_by_step, accepted_token_ids_by_step,
  623. self.scorer_worker.model_config.max_logprobs)
  624. # Get the sequence ids and num_logprobs (sampling parameter) in the
  625. # batch.
  626. seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
  627. seq_group_metadata_list)
  628. num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
  629. # Serialize tensor to CPU Python list.
  630. accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
  631. # Construct the output on a per-step, per-sequence basis.
  632. sampler_output_list: List[SamplerOutput] = []
  633. for step_index in range(num_steps):
  634. if all(token_id == -1
  635. for token_id in accepted_token_ids_by_step[step_index]):
  636. break
  637. step_output_token_ids: List[CompletionSequenceGroupOutput] = []
  638. for sequence_index in range(batch_size):
  639. # Each sequence may have a different num_logprobs; retrieve it.
  640. num_logprobs = num_logprobs_per_seq[sequence_index]
  641. step_output_token_ids.append(
  642. create_sequence_group_output(
  643. token_id=accepted_token_ids_by_step[step_index]
  644. [sequence_index],
  645. token_id_logprob_rank=accepted_token_id_ranks_by_step[
  646. step_index][sequence_index],
  647. token_id_logprob=accepted_token_id_logprobs_by_step[
  648. step_index][sequence_index],
  649. seq_id=seq_ids[sequence_index],
  650. topk_token_ids=topk_indices_by_step[step_index]
  651. [sequence_index][:num_logprobs],
  652. topk_logprobs=topk_logprobs_by_step[step_index]
  653. [sequence_index][:num_logprobs],
  654. ))
  655. sampler_output_list.append(
  656. SamplerOutput(outputs=step_output_token_ids))
  657. # Populate the data structures needed to keep track of sequences with
  658. # bonus tokens.
  659. self._track_sequences_with_bonus_tokens(seq_ids,
  660. request_ids_seq_ids_mapping,
  661. accepted_token_ids_by_step)
  662. maybe_rejsample_metrics = (
  663. self._metrics.maybe_collect_rejsample_metrics(k))
  664. if maybe_rejsample_metrics is not None:
  665. sampler_output_list[
  666. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  667. # Log time spent in each stage periodically.
  668. # This is periodic because the rejection sampler emits metrics
  669. # periodically.
  670. self._maybe_log_stage_times(*stage_times)
  671. return sampler_output_list
  672. def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
  673. scoring_time_ms: float,
  674. verification_time_ms: float) -> None:
  675. """Log the speculative stage times. If stat logging is disabled, do
  676. nothing.
  677. """
  678. if self._disable_log_stats:
  679. return
  680. logger.info(f"SpecDecodeWorker stage times: "
  681. f"average_time_per_proposal_tok_ms="
  682. f"{average_time_per_proposal_tok_ms:.02f} "
  683. f"scoring_time_ms={scoring_time_ms:.02f} "
  684. f"verification_time_ms={verification_time_ms:.02f}")
  685. def _create_dummy_logprob_lists(
  686. self,
  687. batch_size: int,
  688. num_steps: int,
  689. num_top_k: int,
  690. ) -> Tuple[List[List[int]], List[List[float]],
  691. List[List[List[Optional[float]]]],
  692. List[List[List[Optional[int]]]]]:
  693. """
  694. Creates and returns four dummy lists representing token probabilities
  695. and their ranks.
  696. This method initializes and returns:
  697. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  698. - The log probabilities of the accepted tokens,
  699. shaped (num_steps, batch_size)
  700. - The log probabilities of the top k tokens,
  701. shaped (num_steps, batch_size, num_top_k)
  702. - The token IDs of the top k tokens,
  703. shaped (num_steps, batch_size, num_top_k)
  704. Args:
  705. batch_size (int): The size of the batch.
  706. num_steps (int): The number of steps in the sequence.
  707. num_top_k (int): The number of top-k token log probabilities to
  708. return.
  709. Returns:
  710. A tuple containing four dummy lists as described above.
  711. """
  712. accepted_token_id_ranks_by_step = [[-1] * batch_size
  713. for _ in range(num_steps)]
  714. accepted_token_id_logprobs_by_step = [[0.0] * batch_size
  715. for _ in range(num_steps)]
  716. topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
  717. [None] * num_top_k for _ in range(batch_size)
  718. ] for _ in range(num_steps)]
  719. topk_indices_by_step: List[List[List[Optional[int]]]] = [[
  720. [None] * num_top_k for _ in range(batch_size)
  721. ] for _ in range(num_steps)]
  722. return (accepted_token_id_ranks_by_step,
  723. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  724. topk_indices_by_step)
  725. def _create_logprob_lists_from_tensors(
  726. self,
  727. target_logprobs_by_step: torch.Tensor,
  728. accepted_token_ids_by_step: torch.Tensor,
  729. num_top_k: int,
  730. ) -> Tuple[List[List[int]], List[List[float]],
  731. List[List[List[Optional[float]]]],
  732. List[List[List[Optional[int]]]]]:
  733. """
  734. Creates and returns four lists representing token probabilities and
  735. their ranks.
  736. This method initializes and returns four lists containing:
  737. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  738. - The log probabilities of the accepted tokens,
  739. shaped (num_steps, batch_size)
  740. - The log probabilities of the top k tokens,
  741. shaped (num_steps, batch_size, num_top_k)
  742. - The token IDs of the top k tokens,
  743. shaped (num_steps, batch_size, num_top_k)
  744. Args:
  745. target_logprobs_by_step (torch.Tensor): Tensor representing the
  746. log probabilities of the target model,
  747. shaped (num_steps, batch_size, vocab_size)
  748. accepted_token_ids_by_step (torch.Tensor): Tensor representing
  749. the accepted token_ids, shaped (num_steps, batch_size)
  750. num_top_k (int): The number of top-k token log probabilities to
  751. return.
  752. Returns:
  753. A tuple containing the lists as described above.
  754. """
  755. # Serialize all tensors to CPU Python lists.
  756. # Get the logprobs/rank of the accepted tokens.
  757. (accepted_token_id_ranks_by_step_tensor,
  758. accepted_token_id_logprobs_by_step_tensor
  759. ) = get_sampled_token_logprobs(
  760. logprob_tensor=target_logprobs_by_step,
  761. sampled_token_ids=accepted_token_ids_by_step,
  762. )
  763. # Get the top-k logprobs (which may or may not include the
  764. # logprob of the accepted token).
  765. (topk_logprobs_by_step_tensor,
  766. topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
  767. k=num_top_k,
  768. dim=-1,
  769. )
  770. accepted_token_id_ranks_by_step = (
  771. accepted_token_id_ranks_by_step_tensor.tolist())
  772. accepted_token_id_logprobs_by_step = (
  773. accepted_token_id_logprobs_by_step_tensor.tolist())
  774. topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
  775. topk_indices_by_step = topk_indices_by_step_tensor.tolist()
  776. return (accepted_token_id_ranks_by_step,
  777. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  778. topk_indices_by_step)
  779. def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
  780. """
  781. Removes the finished requests and their associated sequence ids from
  782. internal book keeping data structures.
  783. """
  784. for finished_request in execute_model_req.finished_requests_ids:
  785. for seq_id in self._request_id_seq_id_mapping[finished_request]:
  786. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  787. del self._request_id_seq_id_mapping[finished_request]
  788. def _track_sequences_with_bonus_tokens(
  789. self, seq_ids: List[int],
  790. request_ids_seq_ids_mapping: Dict[str, Set[int]],
  791. accepted_token_ids_by_step: List[List[int]]):
  792. """
  793. Updates the internal data structures which keep track of sequences
  794. which have been assigned bonus tokens in their last forward pass.
  795. """
  796. for seq_index, seq_id in enumerate(seq_ids):
  797. last_token_id = accepted_token_ids_by_step[-1][seq_index]
  798. if last_token_id == -1:
  799. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  800. else:
  801. self._seq_with_bonus_token_in_last_step.add(seq_id)
  802. for request_id, sequences in request_ids_seq_ids_mapping.items():
  803. self._request_id_seq_id_mapping[request_id].update(sequences)
  804. @cached_property
  805. def _vocab_size(self) -> int:
  806. """Get the vocab size of the model and make sure it's consistent between
  807. draft and target workers.
  808. """
  809. vocab_sizes = [
  810. worker.vocab_size
  811. for worker in [self.proposer_worker, self.scorer_worker]
  812. ]
  813. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  814. return vocab_sizes[0]
  815. @property
  816. def rank(self):
  817. return self.scorer_worker.rank
  818. @property
  819. def device(self):
  820. return self.scorer_worker.device
  821. @property
  822. def _driver_rank(self) -> int:
  823. return 0
  824. def get_cache_block_size_bytes(self):
  825. """Return the size of a cache block in bytes.
  826. This function is only used to compose workers within a SpecDecodeWorker.
  827. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  828. undefined for now, although it could be implemented in the future.
  829. See https://arxiv.org/abs/2308.04623.
  830. """
  831. raise NotImplementedError
  832. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  833. proposer_cache_block_size_bytes: int,
  834. total_num_gpu_blocks: int) -> int:
  835. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  836. allocate to the target model, this function calculates how many blocks
  837. should be given to the draft and target model.
  838. Note that usually the block size, in bytes, of each model is different,
  839. as it's a function of number of KV/layer, number of heads, and hidden
  840. dimension size.
  841. Since the target and draft models allocate the same number of blocks, we
  842. simply calculate the number of blocks where if allocated by both models,
  843. the total memory usage from KV cache is no larger than the number of
  844. blocks allocatable by the target model alone.
  845. """
  846. new_num_gpu_blocks = int(
  847. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  848. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  849. return new_num_gpu_blocks
  850. def prepare_prefill_hidden_states(
  851. prefill_hidden_states: torch.Tensor) -> HiddenStates:
  852. # For prefill step in proposer, we run the model for N-1 tokens
  853. # because Nth token will be processed in the first decode step. For
  854. # N-1 tokens, the input should be 0:N-1 hidden states which should
  855. # be concatanated with 1:N token (since output of scorer has to be
  856. # the input for proposer). Therefore, we shift the hidden states to
  857. # align n-1th hidden state with nth token.
  858. return HiddenStates(prefill_hidden_states.roll(
  859. shifts=1, dims=0)) if prefill_hidden_states is not None else None