spec_decode_worker.py 45 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981
  1. from collections import defaultdict
  2. from functools import cached_property
  3. from typing import Any, Dict, List, Optional, Set, Tuple
  4. import torch
  5. from loguru import logger
  6. from aphrodite.common.config import ParallelConfig, SpeculativeConfig
  7. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  8. ExecuteModelRequest, HiddenStates,
  9. SequenceGroupMetadata, get_all_seq_ids,
  10. get_all_seq_ids_and_request_ids)
  11. from aphrodite.distributed.communication_op import broadcast_tensor_dict
  12. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  13. from aphrodite.modeling.layers.sampler import SamplerOutput
  14. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  15. SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
  16. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  17. TypicalAcceptanceSampler)
  18. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  19. from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
  20. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  21. SpeculativeScorer,
  22. SpeculativeScores)
  23. from aphrodite.spec_decode.medusa_worker import MedusaWorker
  24. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  25. from aphrodite.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
  26. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  27. from aphrodite.spec_decode.ngram_worker import NGramWorker
  28. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  29. from aphrodite.spec_decode.smaller_tp_proposer_worker import (
  30. SmallerTpProposerWorker)
  31. from aphrodite.spec_decode.target_model_runner import TargetModelRunner
  32. from aphrodite.spec_decode.util import (Timer, create_sequence_group_output,
  33. get_all_num_logprobs,
  34. get_sampled_token_logprobs, nvtx_range,
  35. split_batch_by_proposal_len)
  36. from aphrodite.worker.worker import Worker
  37. from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
  38. def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
  39. """Helper method that is the entrypoint for Executors which use
  40. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
  41. """
  42. assert "speculative_config" in kwargs
  43. speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
  44. assert speculative_config is not None
  45. draft_worker_kwargs = kwargs.copy()
  46. kwargs["model_runner_cls"] = TargetModelRunner
  47. target_worker = Worker(*args, **kwargs)
  48. # Set the disable_logprobs variable in the TargetModelRunner instance
  49. # as per its value specified in the SpeculativeConfig.
  50. target_worker.model_runner.disable_logprobs =\
  51. speculative_config.disable_logprobs
  52. # Override draft-model specific worker args.
  53. draft_worker_kwargs.update(
  54. model_config=speculative_config.draft_model_config,
  55. parallel_config=speculative_config.draft_parallel_config,
  56. ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
  57. ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
  58. # TODO allow draft-model specific load config.
  59. #load_config=load_config,
  60. )
  61. spec_decode_worker = SpecDecodeWorker.create_worker(
  62. scorer_worker=target_worker,
  63. draft_worker_kwargs=draft_worker_kwargs,
  64. disable_by_batch_size=speculative_config.
  65. speculative_disable_by_batch_size,
  66. draft_token_acceptance_method=speculative_config.
  67. draft_token_acceptance_method,
  68. typical_acceptance_sampler_posterior_threshold=speculative_config.
  69. typical_acceptance_sampler_posterior_threshold,
  70. typical_acceptance_sampler_posterior_alpha=speculative_config.
  71. typical_acceptance_sampler_posterior_alpha,
  72. disable_logprobs=speculative_config.disable_logprobs,
  73. disable_log_stats=speculative_config.disable_log_stats,
  74. )
  75. return spec_decode_worker
  76. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  77. """Worker which implements speculative decoding.
  78. Speculative decoding reduces decoding per-token latency by using a proposal
  79. method, such as a small draft model, to speculate ahead of a larger LLM. The
  80. probabilities of the speculative tokens are then determined by the larger
  81. LLM, after which some verification routine determines which (if any) of the
  82. speculative tokens are accepted by the larger LLM.
  83. The current implementation has the following limitations:
  84. * Only draft-model proposal is implemented (contributions for more forms are
  85. welcome!).
  86. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  87. future work.
  88. * All sequences in a batch must have the same proposal length, or zero. This
  89. can be improved by having per-sequence speculation in the future.
  90. * The scoring forward pass is done without an MQA kernel, which is
  91. suboptimal especially as the batch size, proposal length, and sequence
  92. lengths grow. Contributions to add a MQA scoring are welcome once
  93. correctness tests pass.
  94. """
  95. @classmethod
  96. def create_worker(
  97. cls,
  98. scorer_worker: Worker,
  99. draft_worker_kwargs: Dict[str, Any],
  100. disable_by_batch_size: Optional[int],
  101. draft_token_acceptance_method: str,
  102. typical_acceptance_sampler_posterior_threshold: float,
  103. typical_acceptance_sampler_posterior_alpha: float,
  104. disable_logprobs: bool,
  105. disable_log_stats: bool,
  106. ) -> "SpecDecodeWorker":
  107. allow_zero_draft_token_step = True
  108. ngram_prompt_lookup_max = (
  109. draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
  110. ngram_prompt_lookup_min = (
  111. draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
  112. if ngram_prompt_lookup_max > 0:
  113. proposer_worker = NGramWorker(**draft_worker_kwargs)
  114. proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
  115. ngram_prompt_lookup_max)
  116. else:
  117. draft_parallel_config: ParallelConfig = draft_worker_kwargs[
  118. 'parallel_config']
  119. draft_tp = draft_parallel_config.tensor_parallel_size
  120. target_tp = scorer_worker.parallel_config.tensor_parallel_size
  121. if draft_worker_kwargs[
  122. "model_config"].hf_config.model_type == "mlp_speculator":
  123. proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
  124. elif draft_worker_kwargs[
  125. "model_config"].hf_config.model_type == "medusa":
  126. proposer_worker = MedusaWorker(**draft_worker_kwargs)
  127. else:
  128. if draft_tp == 1:
  129. draft_worker_kwargs[
  130. "model_runner_cls"] = TP1DraftModelRunner
  131. else:
  132. if draft_worker_kwargs[
  133. "model_config"].hf_config.model_type == "eagle":
  134. raise NotImplementedError(
  135. "EAGLE does not support TP > 1 yet")
  136. allow_zero_draft_token_step = False
  137. proposer_worker = MultiStepWorker(**draft_worker_kwargs)
  138. proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
  139. proposer_worker, draft_tp, target_tp)
  140. logger.info("Configuring SpecDecodeWorker with "
  141. f"proposer={type(proposer_worker)}")
  142. spec_decode_sampler: SpecDecodeBaseSampler = None
  143. if draft_token_acceptance_method == "rejection_sampler":
  144. spec_decode_sampler = RejectionSampler()
  145. elif draft_token_acceptance_method == "typical_acceptance_sampler":
  146. spec_decode_sampler = TypicalAcceptanceSampler(
  147. posterior_threshold=\
  148. typical_acceptance_sampler_posterior_threshold,
  149. posterior_alpha=typical_acceptance_sampler_posterior_alpha,
  150. )
  151. logger.info("Configuring SpecDecodeWorker with "
  152. f"sampler={type(spec_decode_sampler)}")
  153. return SpecDecodeWorker(
  154. proposer_worker,
  155. scorer_worker,
  156. disable_logprobs=disable_logprobs,
  157. disable_log_stats=disable_log_stats,
  158. disable_by_batch_size=disable_by_batch_size,
  159. spec_decode_sampler=spec_decode_sampler,
  160. allow_zero_draft_token_step=allow_zero_draft_token_step)
  161. def __init__(
  162. self,
  163. proposer_worker: ProposerWorkerBase,
  164. scorer_worker: WorkerBase,
  165. spec_decode_sampler: SpecDecodeBaseSampler,
  166. disable_logprobs: bool = False,
  167. disable_log_stats: bool = False,
  168. metrics_collector: Optional[AsyncMetricsCollector] = None,
  169. disable_by_batch_size: Optional[int] = None,
  170. allow_zero_draft_token_step: Optional[bool] = True,
  171. ):
  172. """
  173. Create a SpecDecodeWorker.
  174. Args:
  175. proposer_worker: A worker that can produce speculative tokens for
  176. sequences.
  177. scorer_worker: A worker that produces probabilities of speculative
  178. tokens according to some base model. Typically a vanilla
  179. Aphrodite Worker.
  180. spec_decode_sampler: A Torch module used to perform acceptance
  181. sampling of the draft tokens in the verification step of
  182. speculative decoding. Currently we support two different
  183. types of sampler namely RejectionSampler and
  184. TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
  185. instance of RejectionSampler or TypicalAcceptanceSampler.
  186. disable_logprobs: If set to True, token log probabilities will
  187. not be output in both the draft worker and the target worker.
  188. If set to False, log probabilities will be output by both.
  189. disable_log_stats: If set to True, disable periodic printing of
  190. speculative stage times.
  191. disable_by_batch_size: If the batch size is larger than this,
  192. disable speculative decoding for new incoming requests.
  193. metrics_collector: Helper class for collecting metrics; can be set
  194. for testing purposes.
  195. allow_zero_draft_token_step: whether to allow a step where the draft
  196. model generates no draft token; should disallow when the tp of
  197. draft model is larger than 1
  198. """
  199. self.proposer_worker = proposer_worker
  200. self.scorer_worker = scorer_worker
  201. scorer_runner = getattr(self.scorer_worker, "model_runner", None)
  202. self.generators = scorer_runner.get_generators(
  203. ) if scorer_runner else None
  204. self.disable_by_batch_size = disable_by_batch_size or float("inf")
  205. self.spec_decode_sampler = spec_decode_sampler
  206. self._allow_zero_draft_token_step = allow_zero_draft_token_step
  207. self._metrics = AsyncMetricsCollector(
  208. self.spec_decode_sampler
  209. ) if metrics_collector is None else metrics_collector
  210. # Tracks the sequence IDs that received a bonus token ID in
  211. # their last forward pass. Needed only if KV cache is being
  212. # used for token generation such as in the case of MultiStepWorker.
  213. self._seq_with_bonus_token_in_last_step: Set[int] = set()
  214. # Tracks the currently active request ids and the sequence IDs
  215. # corresponding to them
  216. self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
  217. # Tracks if the proposer worker uses the KV cache or not.
  218. self.probs_dtype = self.spec_decode_sampler.probs_dtype
  219. self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
  220. # Lazy initialization.
  221. self.scorer: SpeculativeScorer
  222. # Hidden states from target model to pass to proposer
  223. # in the subsequent step.
  224. self.previous_hidden_states: Optional[HiddenStates] = None
  225. self._disable_logprobs = disable_logprobs
  226. self._disable_log_stats = disable_log_stats
  227. def init_device(self) -> None:
  228. """Initialize both scorer and proposer models.
  229. """
  230. # The scorer worker model is initialized first in case the proposer
  231. # model has a smaller TP degree than the target worker.
  232. self.scorer_worker.init_device()
  233. self.proposer_worker.init_device()
  234. # NOTE: load_model is not part of the WorkerBase interface.
  235. self.scorer_worker.load_model()
  236. self.proposer_worker.load_model()
  237. self._metrics.init_gpu_tensors(self.rank)
  238. self.spec_decode_sampler.init_gpu_tensors(self.rank)
  239. self.scorer = BatchExpansionTop1Scorer(
  240. scorer_worker=self.scorer_worker,
  241. device=self.device,
  242. vocab_size=self._vocab_size)
  243. self._configure_model_sampler_for_spec_decode()
  244. def load_model(self, *args, **kwargs):
  245. pass
  246. def _configure_model_sampler_for_spec_decode(self):
  247. """Configure model sampler to emit GPU tensors. This allows spec decode
  248. to keep data on device without transferring to CPU and serializing,
  249. which significantly reduces overhead of sampling during verification.
  250. NOTE: This breaks abstraction boundaries pretty badly. The better
  251. design is to have the "move to CPU and serialize" sampling decision be
  252. done outside of the model/sampler; this way the "last-mile" worker
  253. object which interfaces with the scheduler can serialize and incur the
  254. performance hit as necessary. This allows us to run the worker several
  255. iterations in a row without incurring the "move to CPU and serialize"
  256. performance penalty.
  257. Since this requires a large change to Aphrodite, we defer it to later
  258. and temporarily accept this broken abstraction boundary.
  259. NOTE: This will require a special check if the proposer worker
  260. does not have a sampler (e.g. ngram speculation).
  261. """
  262. (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
  263. ) = True
  264. (self.scorer_worker.model_runner.model.sampler.
  265. should_modify_greedy_probs_inplace) = True
  266. self.proposer_worker.set_include_gpu_probs_tensor()
  267. self.proposer_worker.set_should_modify_greedy_probs_inplace()
  268. def determine_num_available_blocks(self) -> Tuple[int, int]:
  269. """Determine the number of cache blocks to use.
  270. This is done by profiling the scorer model (which is typically the
  271. larger of the two). Then the total memory which would be used by the
  272. scorer cache is divided evenly between the proposer and scorer model KV,
  273. such that the number of blocks is equal in both KV caches.
  274. """
  275. num_gpu_blocks, num_cpu_blocks = (
  276. self.scorer_worker.determine_num_available_blocks())
  277. scorer_cache_block_size_bytes = (
  278. self.scorer_worker.get_cache_block_size_bytes())
  279. proposer_cache_block_size_bytes = (
  280. self.proposer_worker.get_cache_block_size_bytes())
  281. new_num_gpu_blocks = split_num_cache_blocks_evenly(
  282. scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
  283. num_gpu_blocks)
  284. return new_num_gpu_blocks, num_cpu_blocks
  285. def initialize_cache(self, num_gpu_blocks: int,
  286. num_cpu_blocks: int) -> None:
  287. """Initialize the cache engine of the scorer and proposer workers.
  288. """
  289. self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  290. num_cpu_blocks=num_cpu_blocks)
  291. self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  292. num_cpu_blocks=num_cpu_blocks)
  293. @torch.inference_mode()
  294. def execute_model(
  295. self,
  296. execute_model_req: Optional[ExecuteModelRequest] = None
  297. ) -> List[SamplerOutput]:
  298. """Perform speculative decoding on the input batch.
  299. """
  300. if self.rank != self._driver_rank:
  301. self._run_non_driver_rank()
  302. return []
  303. if execute_model_req is None:
  304. # This signals that there's no more requests to process for now.
  305. # All workers are running infinite loop with broadcast_tensor_dict,
  306. # and it stops the loop when the driver broadcasts an empty input.
  307. # Send an empty input to notify all other workers to stop their
  308. # execution loop.
  309. broadcast_tensor_dict({}, src=0)
  310. return []
  311. self._track_finished_requests(execute_model_req)
  312. disable_all_speculation = self._should_disable_all_speculation(
  313. execute_model_req)
  314. num_lookahead_slots = execute_model_req.num_lookahead_slots
  315. # Speculative decoding is disabled in the following cases:
  316. # 1. Prefill phase: Speculative decoding is not
  317. # used during the prefill phase.
  318. # 2. Auto-disable enabled: The running queue size exceeds
  319. # the specified threshold.
  320. # 3. No request: There are no requests in the batch, or
  321. # none of the requests in the batch have spec decoding enabled.
  322. # In any of these cases, the proposer and scorer workers
  323. # are called normally.
  324. no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
  325. sgm.num_speculative_tokens == 0
  326. for sgm in execute_model_req.seq_group_metadata_list)
  327. # Broadcast how many lookahead slots are scheduled for this step, and
  328. # whether all speculation is disabled, to all non-driver workers.
  329. # This is required as if the number of draft model runs changes
  330. # dynamically, the non-driver workers won't know unless we perform a
  331. # communication to inform them.
  332. # no_spec is used to signal non-driver worker about prefill vs decode
  333. # stage. This is needed to ensure that order of execution of proposer
  334. # and scorer is same in both driver and non-driver workers (i.e.,
  335. # scorer -> proposer for prefill and proposer -> scorer in decode). This
  336. # order is needed to support models like EAGLE that take scorer states
  337. # as inputs.
  338. broadcast_dict = dict(
  339. num_lookahead_slots=num_lookahead_slots,
  340. no_spec=no_spec,
  341. disable_all_speculation=disable_all_speculation,
  342. )
  343. broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
  344. assert execute_model_req.seq_group_metadata_list is not None, (
  345. "speculative decoding requires non-None seq_group_metadata_list")
  346. self._maybe_disable_speculative_tokens(
  347. disable_all_speculation, execute_model_req.seq_group_metadata_list)
  348. if no_spec:
  349. return self._run_no_spec(execute_model_req,
  350. skip_proposer=disable_all_speculation)
  351. return self._run_speculative_decoding_step(execute_model_req,
  352. num_lookahead_slots)
  353. @torch.inference_mode()
  354. def start_worker_execution_loop(self) -> None:
  355. """Execute model loop to perform speculative decoding
  356. in parallel worker."""
  357. while self._run_non_driver_rank():
  358. pass
  359. def _should_disable_all_speculation(
  360. self, execute_model_req: ExecuteModelRequest) -> bool:
  361. # When the batch size is too large, disable speculative decoding
  362. # to stop trading off throughput for latency.
  363. return (execute_model_req.running_queue_size >=
  364. self.disable_by_batch_size)
  365. def _maybe_disable_speculative_tokens(
  366. self, disable_all_speculation: bool,
  367. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  368. if not disable_all_speculation:
  369. return
  370. for seq_group_metadata in seq_group_metadata_list:
  371. # Once num_speculative_tokens is set to 0, the spec decode
  372. # of this request will be disabled forever.
  373. # TODO: We currently store spec decoding specific
  374. # state in the global data structure, but we should maintain
  375. # this state within spec decode worker.
  376. seq_group_metadata.num_speculative_tokens = 0
  377. def _serialize_sampler_output_no_logprobs(
  378. self, execute_model_req: ExecuteModelRequest,
  379. sampler_output: SamplerOutput) -> SamplerOutput:
  380. """
  381. Creates and returns a `SamplerOutput` with only the sampled token IDs
  382. being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
  383. All other parameters in `CompletionSequenceGroupOutput` related to log
  384. probabilities are skipped.
  385. Args:
  386. execute_model_req (ExecuteModelRequest): The model request that
  387. was executed.
  388. sampler_output (SamplerOutput): The output from the sampler with
  389. only GPU tensors populated.
  390. Returns:
  391. SamplerOutput: A new `SamplerOutput` instance containing a list of
  392. `CompletionSequenceGroupOutput` objects with only sampled token
  393. IDs populated.
  394. """
  395. seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
  396. sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
  397. completion_seq_group_output_list: List[
  398. CompletionSequenceGroupOutput] = []
  399. for index, seq_id in enumerate(seq_ids):
  400. completion_seq_group_output_list.append(
  401. create_sequence_group_output(
  402. token_id=sampled_token_ids_list[index][0],
  403. token_id_logprob_rank=-1,
  404. token_id_logprob=0.0,
  405. seq_id=seq_id,
  406. topk_token_ids=[],
  407. topk_logprobs=[],
  408. ))
  409. return SamplerOutput(outputs=completion_seq_group_output_list)
  410. @nvtx_range("spec_decode_worker._run_no_spec")
  411. def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
  412. skip_proposer: bool) -> List[SamplerOutput]:
  413. """Run a single generation step without any speculation. The input is
  414. sent to the proposer and scorer model so that the KV cache is consistent
  415. between the two. When skip_proposer is True, the proposer model is
  416. not called, meaning that the kv-cache in proposer for requests is not
  417. updated, so they cannot enable spec decode in the rest decoding.
  418. """
  419. sampler_output = self.scorer_worker.execute_model(execute_model_req)
  420. assert len(sampler_output) == 1
  421. sampler_output = sampler_output[0]
  422. # Store hidden states from target model execution.
  423. hidden_states = sampler_output.hidden_states
  424. if hidden_states is not None:
  425. if self.previous_hidden_states is None:
  426. self.previous_hidden_states = HiddenStates(
  427. hidden_states, execute_model_req.seq_group_metadata_list)
  428. else:
  429. self.previous_hidden_states.update(
  430. hidden_states, execute_model_req.seq_group_metadata_list)
  431. if not skip_proposer:
  432. # We prepare the prefill hidden states here so that there no
  433. # additional complexity in worker for spec_decode vs non_spec_decode
  434. # flow and execute_model doesn't need additional modifications.
  435. execute_model_req.previous_hidden_states = \
  436. prepare_prefill_hidden_states(
  437. sampler_output.prefill_hidden_states)
  438. self.proposer_worker.execute_model(execute_model_req)
  439. sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
  440. execute_model_req=execute_model_req, sampler_output=sampler_output)
  441. if self._disable_logprobs else
  442. sampler_output)
  443. # Clear device tensors from sampler output. This reduces communication
  444. # overhead when the engine runs in a different process than the workers.
  445. sampler_output.sampled_token_probs = None
  446. sampler_output.sampled_token_ids = None
  447. sampler_output.logprobs = None
  448. return [sampler_output_to_return]
  449. def _run_non_driver_rank(self) -> bool:
  450. """Run proposer and verifier model in non-driver workers. This is used
  451. for both speculation cases (num_lookahead_slots>0) and non-speculation
  452. cases (e.g. prefill).
  453. Returns True if there are remaining sequences to process.
  454. """
  455. assert self.rank != self._driver_rank
  456. data = broadcast_tensor_dict(src=self._driver_rank)
  457. if not data:
  458. return False
  459. num_lookahead_slots = data["num_lookahead_slots"]
  460. # In case of prefill, scorer_worker has to be run before proposer so
  461. # that the hidden states can be propagated to proposer when needed.
  462. if data["no_spec"]:
  463. self.scorer_worker.execute_model()
  464. if not data["disable_all_speculation"]:
  465. # Even if num_lookahead_slots is zero, we want to run the
  466. # proposer model as it may have KV.
  467. #
  468. # We run the proposer once per lookahead slot. In the future we
  469. # should delegate how many times it runs to the proposer.
  470. for _ in range(max(num_lookahead_slots, 1)):
  471. self.proposer_worker.execute_model()
  472. if not data["no_spec"]:
  473. self.scorer_worker.execute_model()
  474. return True
  475. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  476. def _run_speculative_decoding_step(
  477. self, execute_model_req: ExecuteModelRequest,
  478. num_lookahead_slots: int) -> List[SamplerOutput]:
  479. """Execute a single step of speculative decoding.
  480. This invokes the proposer worker to get k speculative tokens for each
  481. sequence, then scores each speculative token using the scoring worker.
  482. Returns a list of SamplerOutput, each containing a single token per
  483. sequence.
  484. """
  485. assert num_lookahead_slots == execute_model_req.num_lookahead_slots
  486. # Pass last hidden states from target model to proposer
  487. execute_model_req.previous_hidden_states = self.previous_hidden_states
  488. self.previous_hidden_states = None
  489. with Timer() as proposal_timer:
  490. # Generate proposals using draft worker.
  491. proposals = self.proposer_worker.get_spec_proposals(
  492. execute_model_req, self._seq_with_bonus_token_in_last_step)
  493. if not self._allow_zero_draft_token_step and proposals.no_proposals:
  494. #TODO: Fix it #5814
  495. raise RuntimeError("Cannot handle cases where distributed draft "
  496. "workers generate no tokens")
  497. execute_model_req.previous_hidden_states = None
  498. with Timer() as scoring_timer:
  499. proposal_scores = self.scorer.score_proposals(
  500. execute_model_req,
  501. proposals,
  502. )
  503. with Timer() as verification_timer:
  504. accepted_token_ids, target_logprobs = self._verify_tokens(
  505. execute_model_req.seq_group_metadata_list, proposal_scores,
  506. proposals, execute_model_req.num_lookahead_slots)
  507. stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
  508. scoring_timer.elapsed_time_ms,
  509. verification_timer.elapsed_time_ms)
  510. return self._create_output_sampler_list(
  511. execute_model_req.seq_group_metadata_list,
  512. accepted_token_ids,
  513. target_logprobs=target_logprobs,
  514. k=execute_model_req.num_lookahead_slots,
  515. stage_times=stage_times)
  516. @nvtx_range("spec_decode_worker._verify_tokens")
  517. def _verify_tokens(
  518. self,
  519. seq_group_metadata_list: List[SequenceGroupMetadata],
  520. proposal_scores: SpeculativeScores,
  521. proposals: SpeculativeProposals,
  522. max_proposal_len: int,
  523. ) -> Tuple[torch.Tensor, torch.Tensor]:
  524. """Determine which speculative tokens are accepted using the
  525. probabilities of each token according to the proposer and scorer models.
  526. Returns a tuple of Tensors, one for the accepted token ids and one for
  527. the logprobs according to the scoring model.
  528. """
  529. proposal_lens_list = proposals.proposal_lens.tolist()
  530. # Aphrodite currently only supports proposal lens equal to zero or the
  531. # batch proposal len. This adds some complexity (splitting the batch
  532. # into spec and non spec sequences) and should be removed in the
  533. # future. It can be done by supporting per-sequence proposal lens.
  534. (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
  535. seq_group_metadata_list, proposal_lens_list)
  536. original_indices = spec_indices + non_spec_indices
  537. # Get probabilities of target model, excluding bonus token.
  538. proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
  539. # Get non-speculative sampled tokens from target model.
  540. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
  541. # Get bonus tokens from target model.
  542. bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
  543. # Get probabilities according to proposal method.
  544. proposal_probs = proposals.proposal_probs[spec_indices]
  545. # Get proposed tokens.
  546. proposal_token_ids = proposals.proposal_token_ids[spec_indices]
  547. # Sampler arguments
  548. sampler_extra_kwargs: Dict[str, Any] = {}
  549. if self.generators and isinstance(self.spec_decode_sampler,
  550. SpecDecodeStochasticBaseSampler):
  551. sampler_extra_kwargs["seeded_seqs"] = {
  552. idx: self.generators[sgm.request_id]
  553. for idx, sgm in enumerate(seq_group_metadata_list)
  554. if sgm.sampling_params.seed is not None
  555. }
  556. accepted_token_ids = self.spec_decode_sampler(
  557. target_probs=proposal_verifier_probs,
  558. bonus_token_ids=bonus_token_ids,
  559. draft_probs=proposal_probs,
  560. draft_token_ids=proposal_token_ids,
  561. **sampler_extra_kwargs,
  562. )
  563. # Append output tokens from non-speculative sequences to
  564. # the accepted token ids tensor.
  565. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
  566. 1).clone()
  567. non_spec_token_ids[:, 1:] = -1
  568. accepted_token_ids = torch.cat(
  569. [accepted_token_ids, non_spec_token_ids])
  570. logprobs = proposal_scores.logprobs
  571. # Rearrange so that results are in the order of the original seq group
  572. # metadata.
  573. accepted_token_ids[original_indices] = accepted_token_ids.clone()
  574. hidden_states = proposal_scores.hidden_states
  575. if hidden_states is not None:
  576. # Contract hidden states based on accepted tokens
  577. hs_size = hidden_states.shape[-1]
  578. accepted_index = accepted_token_ids + 1 # Convert -1 to 0
  579. accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
  580. index = accepted_index[:, None, None].expand(-1, 1, hs_size)
  581. second_last_token_hidden_states = hidden_states[:, -2] # b x d
  582. hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
  583. # Store hidden states from target model for subsequent decode step
  584. self.previous_hidden_states = HiddenStates(
  585. hidden_states, seq_group_metadata_list,
  586. second_last_token_hidden_states)
  587. return accepted_token_ids, logprobs
  588. def _create_output_sampler_list(
  589. self,
  590. seq_group_metadata_list: List[SequenceGroupMetadata],
  591. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  592. target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
  593. k: int,
  594. stage_times: Tuple[float, float, float],
  595. ) -> List[SamplerOutput]:
  596. """Given the accepted token ids, create a list of SamplerOutput.
  597. The output is padded with -1 tokens such that each sequence has
  598. the same number of outputs.
  599. """
  600. batch_size, num_steps = accepted_token_ids.shape
  601. accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
  602. if self._disable_logprobs:
  603. # We are skipping the logprobs. Hence don't serialize the
  604. # logprobs related tensors from the GPU. Instead create
  605. # empty/dummy lists.
  606. (accepted_token_id_ranks_by_step,
  607. accepted_token_id_logprobs_by_step,
  608. topk_logprobs_by_step, topk_indices_by_step) =\
  609. self._create_dummy_logprob_lists(
  610. batch_size, num_steps,
  611. self.scorer_worker.model_config.max_logprobs)
  612. else:
  613. # Organize input tensors by step instead of by sequence.
  614. target_logprobs_by_step = target_logprobs.transpose(0, 1)
  615. # Serialize all tensors into Python lists.
  616. (accepted_token_id_ranks_by_step,
  617. accepted_token_id_logprobs_by_step,
  618. topk_logprobs_by_step, topk_indices_by_step) =\
  619. self._create_logprob_lists_from_tensors(
  620. target_logprobs_by_step, accepted_token_ids_by_step,
  621. self.scorer_worker.model_config.max_logprobs)
  622. # Get the sequence ids and num_logprobs (sampling parameter) in the
  623. # batch.
  624. seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
  625. seq_group_metadata_list)
  626. num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
  627. # Serialize tensor to CPU Python list.
  628. accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
  629. # Construct the output on a per-step, per-sequence basis.
  630. sampler_output_list: List[SamplerOutput] = []
  631. for step_index in range(num_steps):
  632. if all(token_id == -1
  633. for token_id in accepted_token_ids_by_step[step_index]):
  634. break
  635. step_output_token_ids: List[CompletionSequenceGroupOutput] = []
  636. for sequence_index in range(batch_size):
  637. # Each sequence may have a different num_logprobs; retrieve it.
  638. num_logprobs = num_logprobs_per_seq[sequence_index]
  639. step_output_token_ids.append(
  640. create_sequence_group_output(
  641. token_id=accepted_token_ids_by_step[step_index]
  642. [sequence_index],
  643. token_id_logprob_rank=accepted_token_id_ranks_by_step[
  644. step_index][sequence_index],
  645. token_id_logprob=accepted_token_id_logprobs_by_step[
  646. step_index][sequence_index],
  647. seq_id=seq_ids[sequence_index],
  648. topk_token_ids=topk_indices_by_step[step_index]
  649. [sequence_index][:num_logprobs],
  650. topk_logprobs=topk_logprobs_by_step[step_index]
  651. [sequence_index][:num_logprobs],
  652. ))
  653. sampler_output_list.append(
  654. SamplerOutput(outputs=step_output_token_ids))
  655. # Populate the data structures needed to keep track of sequences with
  656. # bonus tokens.
  657. self._track_sequences_with_bonus_tokens(seq_ids,
  658. request_ids_seq_ids_mapping,
  659. accepted_token_ids_by_step)
  660. maybe_rejsample_metrics = (
  661. self._metrics.maybe_collect_rejsample_metrics(k))
  662. if maybe_rejsample_metrics is not None:
  663. sampler_output_list[
  664. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  665. # Log time spent in each stage periodically.
  666. # This is periodic because the rejection sampler emits metrics
  667. # periodically.
  668. self._maybe_log_stage_times(*stage_times)
  669. return sampler_output_list
  670. def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
  671. scoring_time_ms: float,
  672. verification_time_ms: float) -> None:
  673. """Log the speculative stage times. If stat logging is disabled, do
  674. nothing.
  675. """
  676. if self._disable_log_stats:
  677. return
  678. logger.info(f"SpecDecodeWorker stage times: "
  679. f"average_time_per_proposal_tok_ms="
  680. f"{average_time_per_proposal_tok_ms:.02f} "
  681. f"scoring_time_ms={scoring_time_ms:.02f} "
  682. f"verification_time_ms={verification_time_ms:.02f}")
  683. def _create_dummy_logprob_lists(
  684. self,
  685. batch_size: int,
  686. num_steps: int,
  687. num_top_k: int,
  688. ) -> Tuple[List[List[int]], List[List[float]],
  689. List[List[List[Optional[float]]]],
  690. List[List[List[Optional[int]]]]]:
  691. """
  692. Creates and returns four dummy lists representing token probabilities
  693. and their ranks.
  694. This method initializes and returns:
  695. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  696. - The log probabilities of the accepted tokens,
  697. shaped (num_steps, batch_size)
  698. - The log probabilities of the top k tokens,
  699. shaped (num_steps, batch_size, num_top_k)
  700. - The token IDs of the top k tokens,
  701. shaped (num_steps, batch_size, num_top_k)
  702. Args:
  703. batch_size (int): The size of the batch.
  704. num_steps (int): The number of steps in the sequence.
  705. num_top_k (int): The number of top-k token log probabilities to
  706. return.
  707. Returns:
  708. A tuple containing four dummy lists as described above.
  709. """
  710. accepted_token_id_ranks_by_step = [[-1] * batch_size
  711. for _ in range(num_steps)]
  712. accepted_token_id_logprobs_by_step = [[0.0] * batch_size
  713. for _ in range(num_steps)]
  714. topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
  715. [None] * num_top_k for _ in range(batch_size)
  716. ] for _ in range(num_steps)]
  717. topk_indices_by_step: List[List[List[Optional[int]]]] = [[
  718. [None] * num_top_k for _ in range(batch_size)
  719. ] for _ in range(num_steps)]
  720. return (accepted_token_id_ranks_by_step,
  721. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  722. topk_indices_by_step)
  723. def _create_logprob_lists_from_tensors(
  724. self,
  725. target_logprobs_by_step: torch.Tensor,
  726. accepted_token_ids_by_step: torch.Tensor,
  727. num_top_k: int,
  728. ) -> Tuple[List[List[int]], List[List[float]],
  729. List[List[List[Optional[float]]]],
  730. List[List[List[Optional[int]]]]]:
  731. """
  732. Creates and returns four lists representing token probabilities and
  733. their ranks.
  734. This method initializes and returns four lists containing:
  735. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  736. - The log probabilities of the accepted tokens,
  737. shaped (num_steps, batch_size)
  738. - The log probabilities of the top k tokens,
  739. shaped (num_steps, batch_size, num_top_k)
  740. - The token IDs of the top k tokens,
  741. shaped (num_steps, batch_size, num_top_k)
  742. Args:
  743. target_logprobs_by_step (torch.Tensor): Tensor representing the
  744. log probabilities of the target model,
  745. shaped (num_steps, batch_size, vocab_size)
  746. accepted_token_ids_by_step (torch.Tensor): Tensor representing
  747. the accepted token_ids, shaped (num_steps, batch_size)
  748. num_top_k (int): The number of top-k token log probabilities to
  749. return.
  750. Returns:
  751. A tuple containing the lists as described above.
  752. """
  753. # Serialize all tensors to CPU Python lists.
  754. # Get the logprobs/rank of the accepted tokens.
  755. (accepted_token_id_ranks_by_step_tensor,
  756. accepted_token_id_logprobs_by_step_tensor
  757. ) = get_sampled_token_logprobs(
  758. logprob_tensor=target_logprobs_by_step,
  759. sampled_token_ids=accepted_token_ids_by_step,
  760. )
  761. # Get the top-k logprobs (which may or may not include the
  762. # logprob of the accepted token).
  763. (topk_logprobs_by_step_tensor,
  764. topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
  765. k=num_top_k,
  766. dim=-1,
  767. )
  768. accepted_token_id_ranks_by_step = (
  769. accepted_token_id_ranks_by_step_tensor.tolist())
  770. accepted_token_id_logprobs_by_step = (
  771. accepted_token_id_logprobs_by_step_tensor.tolist())
  772. topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
  773. topk_indices_by_step = topk_indices_by_step_tensor.tolist()
  774. return (accepted_token_id_ranks_by_step,
  775. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  776. topk_indices_by_step)
  777. def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
  778. """
  779. Removes the finished requests and their associated sequence ids from
  780. internal book keeping data structures.
  781. """
  782. for finished_request in execute_model_req.finished_requests_ids:
  783. for seq_id in self._request_id_seq_id_mapping[finished_request]:
  784. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  785. del self._request_id_seq_id_mapping[finished_request]
  786. def _track_sequences_with_bonus_tokens(
  787. self, seq_ids: List[int],
  788. request_ids_seq_ids_mapping: Dict[str, Set[int]],
  789. accepted_token_ids_by_step: List[List[int]]):
  790. """
  791. Updates the internal data structures which keep track of sequences
  792. which have been assigned bonus tokens in their last forward pass.
  793. """
  794. for seq_index, seq_id in enumerate(seq_ids):
  795. last_token_id = accepted_token_ids_by_step[-1][seq_index]
  796. if last_token_id == -1:
  797. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  798. else:
  799. self._seq_with_bonus_token_in_last_step.add(seq_id)
  800. for request_id, sequences in request_ids_seq_ids_mapping.items():
  801. self._request_id_seq_id_mapping[request_id].update(sequences)
  802. @cached_property
  803. def _vocab_size(self) -> int:
  804. """Get the vocab size of the model and make sure it's consistent between
  805. draft and target workers.
  806. """
  807. vocab_sizes = [
  808. worker.vocab_size
  809. for worker in [self.proposer_worker, self.scorer_worker]
  810. ]
  811. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  812. return vocab_sizes[0]
  813. @property
  814. def rank(self):
  815. return self.scorer_worker.rank
  816. @property
  817. def device(self):
  818. return self.scorer_worker.device
  819. @property
  820. def _driver_rank(self) -> int:
  821. return 0
  822. def get_cache_block_size_bytes(self):
  823. """Return the size of a cache block in bytes.
  824. This function is only used to compose workers within a SpecDecodeWorker.
  825. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  826. undefined for now, although it could be implemented in the future.
  827. See https://arxiv.org/abs/2308.04623.
  828. """
  829. raise NotImplementedError
  830. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  831. proposer_cache_block_size_bytes: int,
  832. total_num_gpu_blocks: int) -> int:
  833. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  834. allocate to the target model, this function calculates how many blocks
  835. should be given to the draft and target model.
  836. Note that usually the block size, in bytes, of each model is different,
  837. as it's a function of number of KV/layer, number of heads, and hidden
  838. dimension size.
  839. Since the target and draft models allocate the same number of blocks, we
  840. simply calculate the number of blocks where if allocated by both models,
  841. the total memory usage from KV cache is no larger than the number of
  842. blocks allocatable by the target model alone.
  843. """
  844. new_num_gpu_blocks = int(
  845. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  846. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  847. return new_num_gpu_blocks
  848. def prepare_prefill_hidden_states(
  849. prefill_hidden_states: torch.Tensor) -> HiddenStates:
  850. # For prefill step in proposer, we run the model for N-1 tokens
  851. # because Nth token will be processed in the first decode step. For
  852. # N-1 tokens, the input should be 0:N-1 hidden states which should
  853. # be concatanated with 1:N token (since output of scorer has to be
  854. # the input for proposer). Therefore, we shift the hidden states to
  855. # align n-1th hidden state with nth token.
  856. return HiddenStates(prefill_hidden_states.roll(
  857. shifts=1, dims=0)) if prefill_hidden_states is not None else None