1
0

spec_decode_worker.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991
  1. from collections import defaultdict
  2. from functools import cached_property
  3. from typing import Any, Dict, List, Optional, Set, Tuple
  4. import torch
  5. from loguru import logger
  6. from aphrodite.common.config import ParallelConfig, SpeculativeConfig
  7. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  8. ExecuteModelRequest, HiddenStates,
  9. SamplerOutput, SequenceGroupMetadata,
  10. get_all_seq_ids,
  11. get_all_seq_ids_and_request_ids)
  12. from aphrodite.distributed.communication_op import broadcast_tensor_dict
  13. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  14. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  15. SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
  16. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  17. TypicalAcceptanceSampler)
  18. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  19. from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
  20. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  21. SpeculativeScorer,
  22. SpeculativeScores)
  23. from aphrodite.spec_decode.medusa_worker import MedusaWorker
  24. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  25. from aphrodite.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
  26. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  27. from aphrodite.spec_decode.ngram_worker import NGramWorker
  28. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  29. from aphrodite.spec_decode.smaller_tp_proposer_worker import (
  30. SmallerTpProposerWorker)
  31. from aphrodite.spec_decode.target_model_runner import TargetModelRunner
  32. from aphrodite.spec_decode.util import (Timer, create_sequence_group_output,
  33. get_all_num_logprobs,
  34. get_sampled_token_logprobs, nvtx_range,
  35. split_batch_by_proposal_len)
  36. from aphrodite.task_handler.worker import Worker
  37. from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
  38. WorkerBase)
  39. def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
  40. """Helper method that is the entrypoint for Executors which use
  41. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
  42. """
  43. assert "speculative_config" in kwargs
  44. speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
  45. assert speculative_config is not None
  46. draft_worker_kwargs = kwargs.copy()
  47. kwargs["model_runner_cls"] = TargetModelRunner
  48. target_worker = Worker(*args, **kwargs)
  49. # Set the disable_logprobs variable in the TargetModelRunner instance
  50. # as per its value specified in the SpeculativeConfig.
  51. target_worker.model_runner.disable_logprobs =\
  52. speculative_config.disable_logprobs
  53. # Override draft-model specific worker args.
  54. draft_worker_kwargs.update(
  55. model_config=speculative_config.draft_model_config,
  56. parallel_config=speculative_config.draft_parallel_config,
  57. ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
  58. ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
  59. # TODO allow draft-model specific load config.
  60. #load_config=load_config,
  61. )
  62. spec_decode_worker = SpecDecodeWorker.create_worker(
  63. scorer_worker=target_worker,
  64. draft_worker_kwargs=draft_worker_kwargs,
  65. disable_by_batch_size=speculative_config.
  66. speculative_disable_by_batch_size,
  67. draft_token_acceptance_method=speculative_config.
  68. draft_token_acceptance_method,
  69. typical_acceptance_sampler_posterior_threshold=speculative_config.
  70. typical_acceptance_sampler_posterior_threshold,
  71. typical_acceptance_sampler_posterior_alpha=speculative_config.
  72. typical_acceptance_sampler_posterior_alpha,
  73. disable_logprobs=speculative_config.disable_logprobs,
  74. disable_log_stats=speculative_config.disable_log_stats,
  75. )
  76. return spec_decode_worker
  77. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  78. """Worker which implements speculative decoding.
  79. Speculative decoding reduces decoding per-token latency by using a proposal
  80. method, such as a small draft model, to speculate ahead of a larger LLM. The
  81. probabilities of the speculative tokens are then determined by the larger
  82. LLM, after which some verification routine determines which (if any) of the
  83. speculative tokens are accepted by the larger LLM.
  84. The current implementation has the following limitations:
  85. * Only draft-model proposal is implemented (contributions for more forms are
  86. welcome!).
  87. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  88. future work.
  89. * All sequences in a batch must have the same proposal length, or zero. This
  90. can be improved by having per-sequence speculation in the future.
  91. * The scoring forward pass is done without an MQA kernel, which is
  92. suboptimal especially as the batch size, proposal length, and sequence
  93. lengths grow. Contributions to add a MQA scoring are welcome once
  94. correctness tests pass.
  95. """
  96. @classmethod
  97. def create_worker(
  98. cls,
  99. scorer_worker: Worker,
  100. draft_worker_kwargs: Dict[str, Any],
  101. disable_by_batch_size: Optional[int],
  102. draft_token_acceptance_method: str,
  103. typical_acceptance_sampler_posterior_threshold: float,
  104. typical_acceptance_sampler_posterior_alpha: float,
  105. disable_logprobs: bool,
  106. disable_log_stats: bool,
  107. ) -> "SpecDecodeWorker":
  108. allow_zero_draft_token_step = True
  109. ngram_prompt_lookup_max = (
  110. draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
  111. ngram_prompt_lookup_min = (
  112. draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
  113. if ngram_prompt_lookup_max > 0:
  114. proposer_worker = NGramWorker(**draft_worker_kwargs)
  115. proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
  116. ngram_prompt_lookup_max)
  117. else:
  118. draft_parallel_config: ParallelConfig = draft_worker_kwargs[
  119. 'parallel_config']
  120. draft_tp = draft_parallel_config.tensor_parallel_size
  121. target_tp = scorer_worker.parallel_config.tensor_parallel_size
  122. if draft_worker_kwargs[
  123. "model_config"].hf_config.model_type == "mlp_speculator":
  124. proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
  125. elif draft_worker_kwargs[
  126. "model_config"].hf_config.model_type == "medusa":
  127. proposer_worker = MedusaWorker(**draft_worker_kwargs)
  128. else:
  129. if draft_tp == 1:
  130. draft_worker_kwargs[
  131. "model_runner_cls"] = TP1DraftModelRunner
  132. else:
  133. if draft_worker_kwargs[
  134. "model_config"].hf_config.model_type == "eagle":
  135. raise NotImplementedError(
  136. "EAGLE does not support TP > 1 yet")
  137. allow_zero_draft_token_step = False
  138. proposer_worker = MultiStepWorker(**draft_worker_kwargs)
  139. proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
  140. proposer_worker, draft_tp, target_tp)
  141. logger.info("Configuring SpecDecodeWorker with "
  142. f"proposer={type(proposer_worker)}")
  143. spec_decode_sampler: SpecDecodeBaseSampler = None
  144. if draft_token_acceptance_method == "rejection_sampler":
  145. spec_decode_sampler = RejectionSampler(
  146. disable_bonus_tokens=False, )
  147. elif draft_token_acceptance_method == "typical_acceptance_sampler":
  148. spec_decode_sampler = TypicalAcceptanceSampler(
  149. disable_bonus_tokens=False,
  150. posterior_threshold=\
  151. typical_acceptance_sampler_posterior_threshold,
  152. posterior_alpha=typical_acceptance_sampler_posterior_alpha,
  153. )
  154. logger.info("Configuring SpecDecodeWorker with "
  155. f"sampler={type(spec_decode_sampler)}")
  156. return SpecDecodeWorker(
  157. proposer_worker,
  158. scorer_worker,
  159. disable_logprobs=disable_logprobs,
  160. disable_log_stats=disable_log_stats,
  161. disable_by_batch_size=disable_by_batch_size,
  162. spec_decode_sampler=spec_decode_sampler,
  163. allow_zero_draft_token_step=allow_zero_draft_token_step)
  164. def __init__(
  165. self,
  166. proposer_worker: ProposerWorkerBase,
  167. scorer_worker: WorkerBase,
  168. spec_decode_sampler: SpecDecodeBaseSampler,
  169. disable_logprobs: bool = False,
  170. disable_log_stats: bool = False,
  171. metrics_collector: Optional[AsyncMetricsCollector] = None,
  172. disable_by_batch_size: Optional[int] = None,
  173. allow_zero_draft_token_step: Optional[bool] = True,
  174. ):
  175. """
  176. Create a SpecDecodeWorker.
  177. Args:
  178. proposer_worker: A worker that can produce speculative tokens for
  179. sequences.
  180. scorer_worker: A worker that produces probabilities of speculative
  181. tokens according to some base model. Typically a vanilla
  182. Aphrodite Worker.
  183. spec_decode_sampler: A Torch module used to perform acceptance
  184. sampling of the draft tokens in the verification step of
  185. speculative decoding. Currently we support two different
  186. types of sampler namely RejectionSampler and
  187. TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
  188. instance of RejectionSampler or TypicalAcceptanceSampler.
  189. disable_logprobs: If set to True, token log probabilities will
  190. not be output in both the draft worker and the target worker.
  191. If set to False, log probabilities will be output by both.
  192. disable_log_stats: If set to True, disable periodic printing of
  193. speculative stage times.
  194. disable_by_batch_size: If the batch size is larger than this,
  195. disable speculative decoding for new incoming requests.
  196. metrics_collector: Helper class for collecting metrics; can be set
  197. for testing purposes.
  198. allow_zero_draft_token_step: whether to allow a step where the draft
  199. model generates no draft token; should disallow when the tp of
  200. draft model is larger than 1
  201. """
  202. self.proposer_worker = proposer_worker
  203. self.scorer_worker = scorer_worker
  204. scorer_runner = getattr(self.scorer_worker, "model_runner", None)
  205. self.generators = scorer_runner.get_generators(
  206. ) if scorer_runner else None
  207. self.disable_by_batch_size = disable_by_batch_size or float("inf")
  208. self.spec_decode_sampler = spec_decode_sampler
  209. self._allow_zero_draft_token_step = allow_zero_draft_token_step
  210. self._metrics = AsyncMetricsCollector(
  211. self.spec_decode_sampler
  212. ) if metrics_collector is None else metrics_collector
  213. # Tracks the sequence IDs that received a bonus token ID in
  214. # their last forward pass. Needed only if KV cache is being
  215. # used for token generation such as in the case of MultiStepWorker.
  216. self._seq_with_bonus_token_in_last_step: Set[int] = set()
  217. # Tracks the currently active request ids and the sequence IDs
  218. # corresponding to them
  219. self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
  220. # Tracks if the proposer worker uses the KV cache or not.
  221. self.probs_dtype = self.spec_decode_sampler.probs_dtype
  222. self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
  223. # Lazy initialization.
  224. self.scorer: SpeculativeScorer
  225. # Hidden states from target model to pass to proposer
  226. # in the subsequent step.
  227. self.previous_hidden_states: Optional[HiddenStates] = None
  228. self._disable_logprobs = disable_logprobs
  229. self._disable_log_stats = disable_log_stats
  230. def init_device(self) -> None:
  231. """Initialize both scorer and proposer models.
  232. """
  233. # The scorer worker model is initialized first in case the proposer
  234. # model has a smaller TP degree than the target worker.
  235. self.scorer_worker.init_device()
  236. self.proposer_worker.init_device()
  237. # NOTE: load_model is not part of the WorkerBase interface.
  238. self.scorer_worker.load_model()
  239. self.proposer_worker.load_model()
  240. self._metrics.init_gpu_tensors(self.rank)
  241. self.spec_decode_sampler.init_gpu_tensors(self.rank)
  242. self.scorer = BatchExpansionTop1Scorer(
  243. scorer_worker=self.scorer_worker,
  244. device=self.device,
  245. vocab_size=self._vocab_size)
  246. self._configure_model_sampler_for_spec_decode()
  247. def load_model(self, *args, **kwargs):
  248. pass
  249. def _configure_model_sampler_for_spec_decode(self):
  250. """Configure model sampler to emit GPU tensors. This allows spec decode
  251. to keep data on device without transferring to CPU and serializing,
  252. which significantly reduces overhead of sampling during verification.
  253. NOTE: This breaks abstraction boundaries pretty badly. The better
  254. design is to have the "move to CPU and serialize" sampling decision be
  255. done outside of the model/sampler; this way the "last-mile" worker
  256. object which interfaces with the scheduler can serialize and incur the
  257. performance hit as necessary. This allows us to run the worker several
  258. iterations in a row without incurring the "move to CPU and serialize"
  259. performance penalty.
  260. Since this requires a large change to Aphrodite, we defer it to later
  261. and temporarily accept this broken abstraction boundary.
  262. NOTE: This will require a special check if the proposer worker
  263. does not have a sampler (e.g. ngram speculation).
  264. """
  265. (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
  266. ) = True
  267. (self.scorer_worker.model_runner.model.sampler.
  268. should_modify_greedy_probs_inplace) = True
  269. self.proposer_worker.set_include_gpu_probs_tensor()
  270. self.proposer_worker.set_should_modify_greedy_probs_inplace()
  271. def determine_num_available_blocks(self) -> Tuple[int, int]:
  272. """Determine the number of cache blocks to use.
  273. This is done by profiling the scorer model (which is typically the
  274. larger of the two). Then the total memory which would be used by the
  275. scorer cache is divided evenly between the proposer and scorer model KV,
  276. such that the number of blocks is equal in both KV caches.
  277. """
  278. num_gpu_blocks, num_cpu_blocks = (
  279. self.scorer_worker.determine_num_available_blocks())
  280. scorer_cache_block_size_bytes = (
  281. self.scorer_worker.get_cache_block_size_bytes())
  282. proposer_cache_block_size_bytes = (
  283. self.proposer_worker.get_cache_block_size_bytes())
  284. new_num_gpu_blocks = split_num_cache_blocks_evenly(
  285. scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
  286. num_gpu_blocks)
  287. return new_num_gpu_blocks, num_cpu_blocks
  288. def initialize_cache(self, num_gpu_blocks: int,
  289. num_cpu_blocks: int) -> None:
  290. """Initialize the cache engine of the scorer and proposer workers.
  291. """
  292. self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  293. num_cpu_blocks=num_cpu_blocks)
  294. self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  295. num_cpu_blocks=num_cpu_blocks)
  296. @torch.inference_mode()
  297. def execute_model(
  298. self,
  299. execute_model_req: Optional[ExecuteModelRequest] = None
  300. ) -> List[SamplerOutput]:
  301. """Perform speculative decoding on the input batch.
  302. """
  303. if self.rank != self._driver_rank:
  304. self._run_non_driver_rank()
  305. return []
  306. if execute_model_req is None:
  307. # This signals that there's no more requests to process for now.
  308. # All workers are running infinite loop with broadcast_tensor_dict,
  309. # and it stops the loop when the driver broadcasts an empty input.
  310. # Send an empty input to notify all other workers to stop their
  311. # execution loop.
  312. broadcast_tensor_dict({}, src=0)
  313. return []
  314. self._track_finished_requests(execute_model_req)
  315. disable_all_speculation = self._should_disable_all_speculation(
  316. execute_model_req)
  317. num_lookahead_slots = execute_model_req.num_lookahead_slots
  318. # Speculative decoding is disabled in the following cases:
  319. # 1. Prefill phase: Speculative decoding is not
  320. # used during the prefill phase.
  321. # 2. Auto-disable enabled: The running queue size exceeds
  322. # the specified threshold.
  323. # 3. No request: There are no requests in the batch.
  324. # In any of these cases, the proposer and scorer workers
  325. # are called normally.
  326. no_spec = num_lookahead_slots == 0 or len(
  327. execute_model_req.seq_group_metadata_list
  328. ) == 0 or disable_all_speculation
  329. # Broadcast how many lookahead slots are scheduled for this step, and
  330. # whether all speculation is disabled, to all non-driver workers.
  331. # This is required as if the number of draft model runs changes
  332. # dynamically, the non-driver workers won't know unless we perform a
  333. # communication to inform them.
  334. # no_spec is used to signal non-driver worker about prefill vs decode
  335. # stage. This is needed to ensure that order of execution of proposer
  336. # and scorer is same in both driver and non-driver workers (i.e.,
  337. # scorer -> proposer for prefill and proposer -> scorer in decode). This
  338. # order is needed to support models like EAGLE that take scorer states
  339. # as inputs.
  340. broadcast_dict = dict(
  341. num_lookahead_slots=num_lookahead_slots,
  342. no_spec=no_spec,
  343. disable_all_speculation=disable_all_speculation,
  344. )
  345. broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
  346. assert execute_model_req.seq_group_metadata_list is not None, (
  347. "speculative decoding requires non-None seq_group_metadata_list")
  348. self._maybe_disable_speculative_tokens(
  349. disable_all_speculation, execute_model_req.seq_group_metadata_list)
  350. if no_spec:
  351. return self._run_no_spec(execute_model_req,
  352. skip_proposer=disable_all_speculation)
  353. return self._run_speculative_decoding_step(execute_model_req,
  354. num_lookahead_slots)
  355. @torch.inference_mode()
  356. def start_worker_execution_loop(self) -> None:
  357. """Execute model loop to perform speculative decoding
  358. in parallel worker."""
  359. while self._run_non_driver_rank():
  360. pass
  361. def _should_disable_all_speculation(
  362. self, execute_model_req: ExecuteModelRequest) -> bool:
  363. # When the batch size is too large, disable speculative decoding
  364. # to stop trading off throughput for latency.
  365. disable_all_speculation = (execute_model_req.running_queue_size >=
  366. self.disable_by_batch_size)
  367. return disable_all_speculation
  368. def _maybe_disable_speculative_tokens(
  369. self, disable_all_speculation: bool,
  370. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  371. if not disable_all_speculation:
  372. return
  373. for seq_group_metadata in seq_group_metadata_list:
  374. # Once num_speculative_tokens is set to 0, the spec decode
  375. # of this request will be disabled forever.
  376. # TODO: We currently store spec decoding specific
  377. # state in the global data structure, but we should maintain
  378. # this state within spec decode worker.
  379. seq_group_metadata.num_speculative_tokens = 0
  380. def _serialize_sampler_output_no_logprobs(
  381. self, execute_model_req: ExecuteModelRequest,
  382. sampler_output: SamplerOutput) -> SamplerOutput:
  383. """
  384. Creates and returns a `SamplerOutput` with only the sampled token IDs
  385. being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
  386. All other parameters in `CompletionSequenceGroupOutput` related to log
  387. probabilities are skipped.
  388. Args:
  389. execute_model_req (ExecuteModelRequest): The model request that
  390. was executed.
  391. sampler_output (SamplerOutput): The output from the sampler with
  392. only GPU tensors populated.
  393. Returns:
  394. SamplerOutput: A new `SamplerOutput` instance containing a list of
  395. `CompletionSequenceGroupOutput` objects with only sampled token
  396. IDs populated.
  397. """
  398. seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
  399. sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
  400. completion_seq_group_output_list: List[
  401. CompletionSequenceGroupOutput] = []
  402. for index, seq_id in enumerate(seq_ids):
  403. completion_seq_group_output_list.append(
  404. create_sequence_group_output(
  405. token_id=sampled_token_ids_list[index][0],
  406. token_id_logprob_rank=-1,
  407. token_id_logprob=0.0,
  408. seq_id=seq_id,
  409. topk_token_ids=[],
  410. topk_logprobs=[],
  411. ))
  412. return SamplerOutput(outputs=completion_seq_group_output_list)
  413. @nvtx_range("spec_decode_worker._run_no_spec")
  414. def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
  415. skip_proposer: bool) -> List[SamplerOutput]:
  416. """Run a single generation step without any speculation. The input is
  417. sent to the proposer and scorer model so that the KV cache is consistent
  418. between the two. When skip_proposer is True, the proposer model is
  419. not called, meaning that the kv-cache in proposer for requests is not
  420. updated, so they cannot enable spec decode in the rest decoding.
  421. """
  422. sampler_output = self.scorer_worker.execute_model(execute_model_req)
  423. assert len(sampler_output) == 1
  424. sampler_output = sampler_output[0]
  425. # Store hidden states from target model execution.
  426. hidden_states = sampler_output.hidden_states
  427. if hidden_states is not None:
  428. if self.previous_hidden_states is None:
  429. self.previous_hidden_states = HiddenStates(
  430. hidden_states, execute_model_req.seq_group_metadata_list)
  431. else:
  432. self.previous_hidden_states.update(
  433. hidden_states, execute_model_req.seq_group_metadata_list)
  434. if not skip_proposer:
  435. # We prepare the prefill hidden states here so that there no
  436. # additional complexity in worker for spec_decode vs non_spec_decode
  437. # flow and execute_model doesn't need additional modifications.
  438. execute_model_req.previous_hidden_states = \
  439. prepare_prefill_hidden_states(
  440. sampler_output.prefill_hidden_states)
  441. self.proposer_worker.execute_model(execute_model_req)
  442. sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
  443. execute_model_req=execute_model_req, sampler_output=sampler_output)
  444. if self._disable_logprobs else
  445. sampler_output)
  446. # Clear device tensors from sampler output. This reduces communication
  447. # overhead when the engine runs in a different process than the workers.
  448. sampler_output.sampled_token_probs = None
  449. sampler_output.sampled_token_ids = None
  450. sampler_output.logprobs = None
  451. return [sampler_output_to_return]
  452. def _run_non_driver_rank(self) -> bool:
  453. """Run proposer and verifier model in non-driver workers. This is used
  454. for both speculation cases (num_lookahead_slots>0) and non-speculation
  455. cases (e.g. prefill).
  456. Returns True if there are remaining sequences to process.
  457. """
  458. assert self.rank != self._driver_rank
  459. data = broadcast_tensor_dict(src=self._driver_rank)
  460. if not data:
  461. return False
  462. num_lookahead_slots = data["num_lookahead_slots"]
  463. # In case of prefill, scorer_worker has to be run before proposer so
  464. # that the hidden states can be propagated to proposer when needed.
  465. if data["no_spec"]:
  466. self.scorer_worker.execute_model()
  467. if not data["disable_all_speculation"]:
  468. # Even if num_lookahead_slots is zero, we want to run the
  469. # proposer model as it may have KV.
  470. #
  471. # We run the proposer once per lookahead slot. In the future we
  472. # should delegate how many times it runs to the proposer.
  473. for _ in range(max(num_lookahead_slots, 1)):
  474. self.proposer_worker.execute_model()
  475. if not data["no_spec"]:
  476. self.scorer_worker.execute_model()
  477. return True
  478. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  479. def _run_speculative_decoding_step(
  480. self, execute_model_req: ExecuteModelRequest,
  481. num_lookahead_slots: int) -> List[SamplerOutput]:
  482. """Execute a single step of speculative decoding.
  483. This invokes the proposer worker to get k speculative tokens for each
  484. sequence, then scores each speculative token using the scoring worker.
  485. Returns a list of SamplerOutput, each containing a single token per
  486. sequence.
  487. """
  488. assert num_lookahead_slots == execute_model_req.num_lookahead_slots
  489. # Pass last hidden states from target model to proposer
  490. execute_model_req.previous_hidden_states = self.previous_hidden_states
  491. self.previous_hidden_states = None
  492. with Timer() as proposal_timer:
  493. # Generate proposals using draft worker.
  494. proposals = self.proposer_worker.get_spec_proposals(
  495. execute_model_req, self._seq_with_bonus_token_in_last_step)
  496. if not self._allow_zero_draft_token_step and proposals.no_proposals:
  497. #TODO: Fix it #5814
  498. raise RuntimeError("Cannot handle cases where distributed draft "
  499. "workers generate no tokens")
  500. execute_model_req.previous_hidden_states = None
  501. with Timer() as scoring_timer:
  502. proposal_scores = self.scorer.score_proposals(
  503. execute_model_req,
  504. proposals,
  505. )
  506. with Timer() as verification_timer:
  507. accepted_token_ids, target_logprobs = self._verify_tokens(
  508. execute_model_req.seq_group_metadata_list, proposal_scores,
  509. proposals, execute_model_req.num_lookahead_slots)
  510. stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
  511. scoring_timer.elapsed_time_ms,
  512. verification_timer.elapsed_time_ms)
  513. return self._create_output_sampler_list(
  514. execute_model_req.seq_group_metadata_list,
  515. accepted_token_ids,
  516. target_logprobs=target_logprobs,
  517. k=execute_model_req.num_lookahead_slots,
  518. stage_times=stage_times)
  519. @nvtx_range("spec_decode_worker._verify_tokens")
  520. def _verify_tokens(
  521. self,
  522. seq_group_metadata_list: List[SequenceGroupMetadata],
  523. proposal_scores: SpeculativeScores,
  524. proposals: SpeculativeProposals,
  525. max_proposal_len: int,
  526. ) -> Tuple[torch.Tensor, torch.Tensor]:
  527. """Determine which speculative tokens are accepted using the
  528. probabilities of each token according to the proposer and scorer models.
  529. Returns a tuple of Tensors, one for the accepted token ids and one for
  530. the logprobs according to the scoring model.
  531. """
  532. proposal_lens_list = proposals.proposal_lens.tolist()
  533. # Aphrodite currently only supports proposal lens equal to zero or the
  534. # batch proposal len. This adds some complexity (splitting the batch
  535. # into spec and non spec sequences) and should be removed in the
  536. # future. It can be done by supporting per-sequence proposal lens.
  537. _, spec_indices = split_batch_by_proposal_len(
  538. seq_group_metadata_list,
  539. proposal_lens_list,
  540. select_proposal_len_zero=False)
  541. _, non_spec_indices = split_batch_by_proposal_len(
  542. seq_group_metadata_list,
  543. proposal_lens_list,
  544. select_proposal_len_zero=True)
  545. original_indices = spec_indices + non_spec_indices
  546. # Get probabilities of target model, excluding bonus token.
  547. proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
  548. # Get non-speculative sampled tokens from target model.
  549. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
  550. # Get bonus tokens from target model.
  551. bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
  552. # Get probabilities according to proposal method.
  553. proposal_probs = proposals.proposal_probs[spec_indices]
  554. # Get proposed tokens.
  555. proposal_token_ids = proposals.proposal_token_ids[spec_indices]
  556. # Sampler arguments
  557. sampler_extra_kwargs: Dict[str, Any] = {}
  558. if self.generators and isinstance(self.spec_decode_sampler,
  559. SpecDecodeStochasticBaseSampler):
  560. sampler_extra_kwargs["seeded_seqs"] = {
  561. idx: self.generators[sgm.request_id]
  562. for idx, sgm in enumerate(seq_group_metadata_list)
  563. if sgm.sampling_params.seed is not None
  564. }
  565. accepted_token_ids = self.spec_decode_sampler(
  566. target_probs=proposal_verifier_probs,
  567. bonus_token_ids=bonus_token_ids,
  568. draft_probs=proposal_probs,
  569. draft_token_ids=proposal_token_ids,
  570. **sampler_extra_kwargs,
  571. )
  572. # Append output tokens from non-speculative sequences to
  573. # the accepted token ids tensor.
  574. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
  575. 1).clone()
  576. non_spec_token_ids[:, 1:] = -1
  577. accepted_token_ids = torch.cat(
  578. [accepted_token_ids, non_spec_token_ids])
  579. logprobs = proposal_scores.logprobs
  580. # Rearrange so that results are in the order of the original seq group
  581. # metadata.
  582. accepted_token_ids[original_indices] = accepted_token_ids.clone()
  583. hidden_states = proposal_scores.hidden_states
  584. if hidden_states is not None:
  585. # Contract hidden states based on accepted tokens
  586. hs_size = hidden_states.shape[-1]
  587. accepted_index = accepted_token_ids + 1 # Convert -1 to 0
  588. accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
  589. index = accepted_index[:, None, None].expand(-1, 1, hs_size)
  590. second_last_token_hidden_states = hidden_states[:, -2] # b x d
  591. hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
  592. # Store hidden states from target model for subsequent decode step
  593. self.previous_hidden_states = HiddenStates(
  594. hidden_states, seq_group_metadata_list,
  595. second_last_token_hidden_states)
  596. return accepted_token_ids, logprobs
def _create_output_sampler_list(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
    target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
    k: int,
    stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
    """Given the accepted token ids, create a list of SamplerOutput.

    The output is padded with -1 tokens such that each sequence has
    the same number of outputs.

    Args:
        seq_group_metadata_list: Metadata for every sequence group in the
            batch; used to look up seq ids, request ids and the per-sequence
            number of requested logprobs.
        accepted_token_ids: Accepted token ids per sequence, -1 marking an
            empty slot (shape per the inline annotation).
        target_logprobs: Target-model logprobs; only read when logprobs
            are not disabled.
        k: Forwarded to ``maybe_collect_rejsample_metrics``.
        stage_times: (average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms), forwarded to ``_maybe_log_stage_times``.

    Returns:
        One SamplerOutput per step, up to (not including) the first step in
        which every sequence's token is -1. Each SamplerOutput contains one
        output per sequence in the batch.
    """
    batch_size, num_steps = accepted_token_ids.shape
    # Step-major view: row i holds every sequence's token for step i.
    accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
    if self._disable_logprobs:
        # We are skipping the logprobs. Hence don't serialize the
        # logprobs related tensors from the GPU. Instead create
        # empty/dummy lists.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step,
         topk_logprobs_by_step, topk_indices_by_step) =\
        self._create_dummy_logprob_lists(
            batch_size, num_steps,
            self.scorer_worker.model_config.max_logprobs)
    else:
        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        # Serialize all tensors into Python lists.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step,
         topk_logprobs_by_step, topk_indices_by_step) =\
        self._create_logprob_lists_from_tensors(
            target_logprobs_by_step, accepted_token_ids_by_step,
            self.scorer_worker.model_config.max_logprobs)

    # Get the sequence ids and num_logprobs (sampling parameter) in the
    # batch.
    seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
        seq_group_metadata_list)
    num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

    # Serialize tensor to CPU Python list.
    accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

    # Construct the output on a per-step, per-sequence basis.
    sampler_output_list: List[SamplerOutput] = []
    for step_index in range(num_steps):
        # Stop at the first step in which no sequence has a token (-1 is the
        # padding value).
        if all(token_id == -1
               for token_id in accepted_token_ids_by_step[step_index]):
            break

        step_output_token_ids: List[CompletionSequenceGroupOutput] = []
        for sequence_index in range(batch_size):
            # Each sequence may have a different num_logprobs; retrieve it.
            num_logprobs = num_logprobs_per_seq[sequence_index]
            step_output_token_ids.append(
                create_sequence_group_output(
                    token_id=accepted_token_ids_by_step[step_index]
                    [sequence_index],
                    token_id_logprob_rank=accepted_token_id_ranks_by_step[
                        step_index][sequence_index],
                    token_id_logprob=accepted_token_id_logprobs_by_step[
                        step_index][sequence_index],
                    seq_id=seq_ids[sequence_index],
                    # The top-k lists were built with max_logprobs entries;
                    # truncate to what this sequence asked for.
                    topk_token_ids=topk_indices_by_step[step_index]
                    [sequence_index][:num_logprobs],
                    topk_logprobs=topk_logprobs_by_step[step_index]
                    [sequence_index][:num_logprobs],
                ))

        sampler_output_list.append(
            SamplerOutput(outputs=step_output_token_ids))

    # Populate the data structures needed to keep track of sequences with
    # bonus tokens.
    self._track_sequences_with_bonus_tokens(seq_ids,
                                            request_ids_seq_ids_mapping,
                                            accepted_token_ids_by_step)
    maybe_rejsample_metrics = (
        self._metrics.maybe_collect_rejsample_metrics(k))
    if maybe_rejsample_metrics is not None:
        # Metrics are attached to the first step's output only.
        # NOTE(review): this raises IndexError if no step was appended
        # (e.g. an empty batch); confirm that cannot happen here.
        sampler_output_list[
            0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)

    return sampler_output_list
  679. def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
  680. scoring_time_ms: float,
  681. verification_time_ms: float) -> None:
  682. """Log the speculative stage times. If stat logging is disabled, do
  683. nothing.
  684. """
  685. if self._disable_log_stats:
  686. return
  687. logger.info(f"SpecDecodeWorker stage times: "
  688. f"average_time_per_proposal_tok_ms="
  689. f"{average_time_per_proposal_tok_ms:.02f} "
  690. f"scoring_time_ms={scoring_time_ms:.02f} "
  691. f"verification_time_ms={verification_time_ms:.02f}")
  692. def _create_dummy_logprob_lists(
  693. self,
  694. batch_size: int,
  695. num_steps: int,
  696. num_top_k: int,
  697. ) -> Tuple[List[List[int]], List[List[float]],
  698. List[List[List[Optional[float]]]],
  699. List[List[List[Optional[int]]]]]:
  700. """
  701. Creates and returns four dummy lists representing token probabilities
  702. and their ranks.
  703. This method initializes and returns:
  704. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  705. - The log probabilities of the accepted tokens,
  706. shaped (num_steps, batch_size)
  707. - The log probabilities of the top k tokens,
  708. shaped (num_steps, batch_size, num_top_k)
  709. - The token IDs of the top k tokens,
  710. shaped (num_steps, batch_size, num_top_k)
  711. Args:
  712. batch_size (int): The size of the batch.
  713. num_steps (int): The number of steps in the sequence.
  714. num_top_k (int): The number of top-k token log probabilities to
  715. return.
  716. Returns:
  717. A tuple containing four dummy lists as described above.
  718. """
  719. accepted_token_id_ranks_by_step = [[-1] * batch_size
  720. for _ in range(num_steps)]
  721. accepted_token_id_logprobs_by_step = [[0.0] * batch_size
  722. for _ in range(num_steps)]
  723. topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
  724. [None] * num_top_k for _ in range(batch_size)
  725. ] for _ in range(num_steps)]
  726. topk_indices_by_step: List[List[List[Optional[int]]]] = [[
  727. [None] * num_top_k for _ in range(batch_size)
  728. ] for _ in range(num_steps)]
  729. return (accepted_token_id_ranks_by_step,
  730. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  731. topk_indices_by_step)
  732. def _create_logprob_lists_from_tensors(
  733. self,
  734. target_logprobs_by_step: torch.Tensor,
  735. accepted_token_ids_by_step: torch.Tensor,
  736. num_top_k: int,
  737. ) -> Tuple[List[List[int]], List[List[float]],
  738. List[List[List[Optional[float]]]],
  739. List[List[List[Optional[int]]]]]:
  740. """
  741. Creates and returns four lists representing token probabilities and
  742. their ranks.
  743. This method initializes and returns four lists containing:
  744. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  745. - The log probabilities of the accepted tokens,
  746. shaped (num_steps, batch_size)
  747. - The log probabilities of the top k tokens,
  748. shaped (num_steps, batch_size, num_top_k)
  749. - The token IDs of the top k tokens,
  750. shaped (num_steps, batch_size, num_top_k)
  751. Args:
  752. target_logprobs_by_step (torch.Tensor): Tensor representing the
  753. log probabilities of the target model,
  754. shaped (num_steps, batch_size, vocab_size)
  755. accepted_token_ids_by_step (torch.Tensor): Tensor representing
  756. the accepted token_ids, shaped (num_steps, batch_size)
  757. num_top_k (int): The number of top-k token log probabilities to
  758. return.
  759. Returns:
  760. A tuple containing the lists as described above.
  761. """
  762. # Serialize all tensors to CPU Python lists.
  763. # Get the logprobs/rank of the accepted tokens.
  764. (accepted_token_id_ranks_by_step_tensor,
  765. accepted_token_id_logprobs_by_step_tensor
  766. ) = get_sampled_token_logprobs(
  767. logprob_tensor=target_logprobs_by_step,
  768. sampled_token_ids=accepted_token_ids_by_step,
  769. )
  770. # Get the top-k logprobs (which may or may not include the
  771. # logprob of the accepted token).
  772. (topk_logprobs_by_step_tensor,
  773. topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
  774. k=num_top_k,
  775. dim=-1,
  776. )
  777. accepted_token_id_ranks_by_step = (
  778. accepted_token_id_ranks_by_step_tensor.tolist())
  779. accepted_token_id_logprobs_by_step = (
  780. accepted_token_id_logprobs_by_step_tensor.tolist())
  781. topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
  782. topk_indices_by_step = topk_indices_by_step_tensor.tolist()
  783. return (accepted_token_id_ranks_by_step,
  784. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  785. topk_indices_by_step)
  786. def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
  787. """
  788. Removes the finished requests and their associated sequence ids from
  789. internal book keeping data structures.
  790. """
  791. for finished_request in execute_model_req.finished_requests_ids:
  792. for seq_id in self._request_id_seq_id_mapping[finished_request]:
  793. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  794. del self._request_id_seq_id_mapping[finished_request]
  795. def _track_sequences_with_bonus_tokens(
  796. self, seq_ids: List[int],
  797. request_ids_seq_ids_mapping: Dict[str, Set[int]],
  798. accepted_token_ids_by_step: List[List[int]]):
  799. """
  800. Updates the internal data structures which keep track of sequences
  801. which have been assigned bonus tokens in their last forward pass.
  802. """
  803. for seq_index, seq_id in enumerate(seq_ids):
  804. last_token_id = accepted_token_ids_by_step[-1][seq_index]
  805. if last_token_id == -1:
  806. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  807. else:
  808. self._seq_with_bonus_token_in_last_step.add(seq_id)
  809. for request_id, sequences in request_ids_seq_ids_mapping.items():
  810. self._request_id_seq_id_mapping[request_id].update(sequences)
  811. @cached_property
  812. def _vocab_size(self) -> int:
  813. """Get the vocab size of the model and make sure it's consistent between
  814. draft and target workers.
  815. """
  816. vocab_sizes = [
  817. worker.vocab_size
  818. for worker in [self.proposer_worker, self.scorer_worker]
  819. ]
  820. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  821. return vocab_sizes[0]
  822. @property
  823. def rank(self):
  824. return self.scorer_worker.rank
  825. @property
  826. def device(self):
  827. return self.scorer_worker.device
  828. @property
  829. def _driver_rank(self) -> int:
  830. return 0
  831. def get_cache_block_size_bytes(self):
  832. """Return the size of a cache block in bytes.
  833. This function is only used to compose workers within a SpecDecodeWorker.
  834. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  835. undefined for now, although it could be implemented in the future.
  836. See https://arxiv.org/abs/2308.04623.
  837. """
  838. raise NotImplementedError
  839. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  840. proposer_cache_block_size_bytes: int,
  841. total_num_gpu_blocks: int) -> int:
  842. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  843. allocate to the target model, this function calculates how many blocks
  844. should be given to the draft and target model.
  845. Note that usually the block size, in bytes, of each model is different,
  846. as it's a function of number of KV/layer, number of heads, and hidden
  847. dimension size.
  848. Since the target and draft models allocate the same number of blocks, we
  849. simply calculate the number of blocks where if allocated by both models,
  850. the total memory usage from KV cache is no larger than the number of
  851. blocks allocatable by the target model alone.
  852. """
  853. new_num_gpu_blocks = int(
  854. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  855. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  856. return new_num_gpu_blocks
  857. def prepare_prefill_hidden_states(
  858. prefill_hidden_states: torch.Tensor) -> HiddenStates:
  859. # For prefill step in proposer, we run the model for N-1 tokens
  860. # because Nth token will be processed in the first decode step. For
  861. # N-1 tokens, the input should be 0:N-1 hidden states which should
  862. # be concatanated with 1:N token (since output of scorer has to be
  863. # the input for proposer). Therefore, we shift the hidden states to
  864. # align n-1th hidden state with nth token.
  865. return HiddenStates(prefill_hidden_states.roll(
  866. shifts=1, dims=0)) if prefill_hidden_states is not None else None