spec_decode_worker.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952
  1. from collections import defaultdict
  2. from functools import cached_property
  3. from typing import Any, Dict, List, Optional, Set, Tuple
  4. import torch
  5. from loguru import logger
  6. from aphrodite.common.config import ParallelConfig, SpeculativeConfig
  7. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  8. ExecuteModelRequest, HiddenStates,
  9. SamplerOutput, SequenceGroupMetadata,
  10. get_all_seq_ids,
  11. get_all_seq_ids_and_request_ids)
  12. from aphrodite.distributed.communication_op import broadcast_tensor_dict
  13. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  14. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  15. SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
  16. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  17. TypicalAcceptanceSampler)
  18. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  19. from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
  20. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  21. SpeculativeScorer,
  22. SpeculativeScores)
  23. from aphrodite.spec_decode.medusa_worker import MedusaWorker
  24. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  25. from aphrodite.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
  26. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  27. from aphrodite.spec_decode.ngram_worker import NGramWorker
  28. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  29. from aphrodite.spec_decode.smaller_tp_proposer_worker import (
  30. SmallerTpProposerWorker)
  31. from aphrodite.spec_decode.target_model_runner import TargetModelRunner
  32. from aphrodite.spec_decode.util import (Timer, create_sequence_group_output,
  33. get_all_num_logprobs,
  34. get_sampled_token_logprobs, nvtx_range,
  35. split_batch_by_proposal_len)
  36. from aphrodite.task_handler.worker import Worker
  37. from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
  38. WorkerBase)
  def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
      """Build a SpecDecodeWorker from the ``speculative_config`` kwarg.

      This is the entrypoint for Executors which use WorkerWrapper. The given
      ``*args``/``**kwargs`` construct the target (scorer) ``Worker`` with a
      ``TargetModelRunner``. A copy of ``kwargs`` with draft-model overrides is
      handed to ``SpecDecodeWorker.create_worker`` to build the proposer.

      Raises:
          ValueError: If ``speculative_config`` is missing or None.
      """
      # Explicit check rather than `assert`: asserts are stripped under
      # `python -O`, which would defer the failure to an AttributeError below.
      speculative_config: Optional[SpeculativeConfig] = kwargs.get(
          "speculative_config")
      if speculative_config is None:
          raise ValueError("create_spec_worker requires a non-None "
                           "'speculative_config' keyword argument")

      # Copy before the target-only override below so the draft worker does
      # not inherit model_runner_cls=TargetModelRunner.
      draft_worker_kwargs = kwargs.copy()

      kwargs["model_runner_cls"] = TargetModelRunner
      target_worker = Worker(*args, **kwargs)
      # Propagate disable_logprobs from the SpeculativeConfig to the
      # TargetModelRunner instance.
      target_worker.model_runner.disable_logprobs = (
          speculative_config.disable_logprobs)

      # Override draft-model specific worker args.
      # TODO: allow a draft-model specific load config.
      draft_worker_kwargs.update(
          model_config=speculative_config.draft_model_config,
          parallel_config=speculative_config.draft_parallel_config,
          ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
          ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
      )

      spec_decode_worker = SpecDecodeWorker.create_worker(
          scorer_worker=target_worker,
          draft_worker_kwargs=draft_worker_kwargs,
          disable_by_batch_size=speculative_config.
          speculative_disable_by_batch_size,
          draft_token_acceptance_method=speculative_config.
          draft_token_acceptance_method,
          typical_acceptance_sampler_posterior_threshold=speculative_config.
          typical_acceptance_sampler_posterior_threshold,
          typical_acceptance_sampler_posterior_alpha=speculative_config.
          typical_acceptance_sampler_posterior_alpha,
          disable_logprobs=speculative_config.disable_logprobs,
          disable_log_stats=speculative_config.disable_log_stats,
      )
      return spec_decode_worker
  77. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  78. """Worker which implements speculative decoding.
  79. Speculative decoding reduces decoding per-token latency by using a proposal
  80. method, such as a small draft model, to speculate ahead of a larger LLM. The
  81. probabilities of the speculative tokens are then determined by the larger
  82. LLM, after which some verification routine determines which (if any) of the
  83. speculative tokens are accepted by the larger LLM.
  84. The current implementation has the following limitations:
  85. * Only draft-model proposal is implemented (contributions for more forms are
  86. welcome!).
  87. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  88. future work.
  89. * All sequences in a batch must have the same proposal length, or zero. This
  90. can be improved by having per-sequence speculation in the future.
  91. * The scoring forward pass is done without an MQA kernel, which is
  92. suboptimal especially as the batch size, proposal length, and sequence
  93. lengths grow. Contributions to add a MQA scoring are welcome once
  94. correctness tests pass.
  95. """
  @classmethod
  def create_worker(
      cls,
      scorer_worker: Worker,
      draft_worker_kwargs: Dict[str, Any],
      disable_by_batch_size: Optional[int],
      draft_token_acceptance_method: str,
      typical_acceptance_sampler_posterior_threshold: float,
      typical_acceptance_sampler_posterior_alpha: float,
      disable_logprobs: bool,
      disable_log_stats: bool,
  ) -> "SpecDecodeWorker":
      """Build a SpecDecodeWorker with a proposer and an acceptance sampler.

      Proposer selection:

      * ``ngram_prompt_lookup_max > 0``: ``NGramWorker``.
      * draft ``hf_config.model_type == "mlp_speculator"``:
        ``MLPSpeculatorWorker``.
      * draft ``hf_config.model_type == "medusa"``: ``MedusaWorker``.
      * otherwise ``MultiStepWorker``. It uses ``TP1DraftModelRunner`` when
        the draft tensor-parallel size is 1. With a larger draft TP, steps in
        which no draft tokens are produced are disallowed.

      Non-n-gram proposers are then passed through
      ``SmallerTpProposerWorker.maybe_wrap_worker`` with the draft and
      target TP sizes.

      Args:
          scorer_worker: Target-model worker used to score proposals.
          draft_worker_kwargs: Constructor kwargs for the proposer worker.
              NOTE: the ``ngram_prompt_lookup_max``/``ngram_prompt_lookup_min``
              keys are popped from this dict (the caller's dict is mutated),
              and ``model_runner_cls`` may be written into it.
          disable_by_batch_size: Forwarded to ``SpecDecodeWorker``.
          draft_token_acceptance_method: ``"rejection_sampler"`` or
              ``"typical_acceptance_sampler"``.
          typical_acceptance_sampler_posterior_threshold: Used only by the
              typical acceptance sampler.
          typical_acceptance_sampler_posterior_alpha: Used only by the
              typical acceptance sampler.
          disable_logprobs: Forwarded to ``SpecDecodeWorker``.
          disable_log_stats: Forwarded to ``SpecDecodeWorker``.

      Returns:
          The configured ``SpecDecodeWorker``.

      Raises:
          ValueError: If ``draft_token_acceptance_method`` is unsupported.
      """
      allow_zero_draft_token_step = True
      # The n-gram window sizes are consumed here; they must not reach the
      # proposer worker constructors.
      ngram_prompt_lookup_max = (
          draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
      ngram_prompt_lookup_min = (
          draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

      if ngram_prompt_lookup_max > 0:
          proposer_worker = NGramWorker(**draft_worker_kwargs)
          proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                ngram_prompt_lookup_max)
      else:
          draft_parallel_config: ParallelConfig = draft_worker_kwargs[
              'parallel_config']
          draft_tp = draft_parallel_config.tensor_parallel_size
          target_tp = scorer_worker.parallel_config.tensor_parallel_size

          draft_model_type = draft_worker_kwargs[
              "model_config"].hf_config.model_type
          if draft_model_type == "mlp_speculator":
              proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
          elif draft_model_type == "medusa":
              proposer_worker = MedusaWorker(**draft_worker_kwargs)
          else:
              if draft_tp == 1:
                  draft_worker_kwargs[
                      "model_runner_cls"] = TP1DraftModelRunner
              else:
                  # Distributed draft workers cannot handle steps with no
                  # draft tokens (see _run_speculative_decoding_step).
                  allow_zero_draft_token_step = False
              proposer_worker = MultiStepWorker(**draft_worker_kwargs)

          proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
              proposer_worker, draft_tp, target_tp)

      logger.info("Configuring SpecDecodeWorker with "
                  f"proposer={type(proposer_worker)}")

      spec_decode_sampler: SpecDecodeBaseSampler
      if draft_token_acceptance_method == "rejection_sampler":
          spec_decode_sampler = RejectionSampler(
              disable_bonus_tokens=False, )
      elif draft_token_acceptance_method == "typical_acceptance_sampler":
          spec_decode_sampler = TypicalAcceptanceSampler(
              disable_bonus_tokens=False,
              posterior_threshold=\
                  typical_acceptance_sampler_posterior_threshold,
              posterior_alpha=typical_acceptance_sampler_posterior_alpha,
          )
      else:
          # Fail loudly here: a None sampler would otherwise crash inside
          # __init__ with an unhelpful AttributeError on `.probs_dtype`.
          raise ValueError(
              "Unsupported draft_token_acceptance_method "
              f"{draft_token_acceptance_method!r}; expected "
              "'rejection_sampler' or 'typical_acceptance_sampler'.")

      logger.info("Configuring SpecDecodeWorker with "
                  f"sampler={type(spec_decode_sampler)}")

      return SpecDecodeWorker(
          proposer_worker,
          scorer_worker,
          disable_logprobs=disable_logprobs,
          disable_by_batch_size=disable_by_batch_size,
          disable_log_stats=disable_log_stats,
          spec_decode_sampler=spec_decode_sampler,
          allow_zero_draft_token_step=allow_zero_draft_token_step)
  def __init__(
      self,
      proposer_worker: ProposerWorkerBase,
      scorer_worker: WorkerBase,
      spec_decode_sampler: SpecDecodeBaseSampler,
      disable_logprobs: bool = False,
      disable_log_stats: bool = False,
      metrics_collector: Optional[AsyncMetricsCollector] = None,
      disable_by_batch_size: Optional[int] = None,
      allow_zero_draft_token_step: bool = True,
  ):
      """
      Create a SpecDecodeWorker.

      Args:
          proposer_worker: A worker that can produce speculative tokens for
              sequences.
          scorer_worker: A worker that produces probabilities of speculative
              tokens according to some base model. Typically a vanilla
              Aphrodite Worker.
          spec_decode_sampler: A Torch module used to perform acceptance
              sampling of the draft tokens in the verification step of
              speculative decoding. Currently two samplers are supported,
              RejectionSampler and TypicalAcceptanceSampler.
          disable_logprobs: If set to True, token log probabilities will
              not be output in both the draft worker and the target worker.
              If set to False, log probabilities will be output by both.
          disable_log_stats: If set to True, disable periodic printing of
              speculative stage times.
          metrics_collector: Helper class for collecting metrics; can be set
              for testing purposes. Defaults to an AsyncMetricsCollector
              built around ``spec_decode_sampler``.
          disable_by_batch_size: When the running queue size is at least
              this value, speculative decoding is disabled for the incoming
              requests. None means never disable.
          allow_zero_draft_token_step: Whether to allow a step where the
              draft model generates no draft token; should be disallowed
              when the TP of the draft model is larger than 1.
      """
      self.proposer_worker = proposer_worker
      self.scorer_worker = scorer_worker
      # Generators owned by the scorer's model runner (if it has one). They
      # are looked up by request_id for seeded sequences in _verify_tokens.
      scorer_runner = getattr(self.scorer_worker, "model_runner", None)
      self.generators = scorer_runner.get_generators(
      ) if scorer_runner else None
      # Compare against None explicitly: `disable_by_batch_size or inf`
      # would silently turn an explicit 0 into "never disable".
      self.disable_by_batch_size = (float("inf")
                                    if disable_by_batch_size is None else
                                    disable_by_batch_size)
      self.spec_decode_sampler = spec_decode_sampler
      self._allow_zero_draft_token_step = allow_zero_draft_token_step
      self._metrics = AsyncMetricsCollector(
          self.spec_decode_sampler
      ) if metrics_collector is None else metrics_collector
      # Tracks the sequence IDs that received a bonus token ID in
      # their last forward pass. Needed only if KV cache is being
      # used for token generation such as in the case of MultiStepWorker.
      self._seq_with_bonus_token_in_last_step: Set[int] = set()
      # Tracks the currently active request ids and the sequence IDs
      # corresponding to them.
      self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
      # Dtypes dictated by the acceptance sampler.
      self.probs_dtype = self.spec_decode_sampler.probs_dtype
      self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
      # Lazy initialization: assigned in init_device().
      self.scorer: SpeculativeScorer
      # Hidden states from target model to pass to proposer
      # in the subsequent step.
      self.previous_hidden_states: Optional[HiddenStates] = None
      self._disable_logprobs = disable_logprobs
      self._disable_log_stats = disable_log_stats
  def init_device(self) -> None:
      """Bring up both the scorer and the proposer models.

      Both workers have their devices initialized and then their models
      loaded. After that, the GPU-side state of the metrics collector and
      the acceptance sampler is set up, the batch-expansion scorer is built,
      and the model samplers are configured for spec decode.
      """
      # Scorer before proposer, in both phases: the proposer model may run
      # with a smaller TP degree than the target worker.
      ordered_workers = (self.scorer_worker, self.proposer_worker)
      for worker in ordered_workers:
          worker.init_device()
      # NOTE: load_model is not part of the WorkerBase interface.
      for worker in ordered_workers:
          worker.load_model()

      self._metrics.init_gpu_tensors(self.rank)
      self.spec_decode_sampler.init_gpu_tensors(self.rank)

      self.scorer = BatchExpansionTop1Scorer(
          scorer_worker=self.scorer_worker,
          device=self.device,
          vocab_size=self._vocab_size,
      )
      self._configure_model_sampler_for_spec_decode()
  def load_model(self, *args, **kwargs):
      """Intentionally a no-op; both models are loaded in init_device()."""
  def _configure_model_sampler_for_spec_decode(self):
      """Make the model samplers emit GPU tensors for spec decode.

      Keeping the sampler's probabilities on device avoids a CPU transfer and
      serialization, which significantly reduces the sampling overhead during
      verification.

      NOTE: This breaks abstraction boundaries pretty badly. The better
      design is to make the "move to CPU and serialize" decision outside the
      model/sampler. The "last-mile" worker that talks to the scheduler could
      then pay that cost only when needed, and the worker could run several
      iterations in a row without incurring it. Since this requires a large
      change to Aphrodite, we defer it to later and temporarily accept this
      broken abstraction boundary.

      NOTE: This will require a special check if the proposer worker does not
      have a sampler (e.g. ngram speculation).
      """
      target_sampler = self.scorer_worker.model_runner.model.sampler
      target_sampler.include_gpu_probs_tensor = True
      target_sampler.should_modify_greedy_probs_inplace = True

      # The proposer side is configured through the proposer worker's setters.
      self.proposer_worker.set_include_gpu_probs_tensor()
      self.proposer_worker.set_should_modify_greedy_probs_inplace()
  def determine_num_available_blocks(self) -> Tuple[int, int]:
      """Return (num_gpu_blocks, num_cpu_blocks) to allocate per worker.

      The scorer (typically the larger model) is profiled. Its GPU block
      budget is then re-split via ``split_num_cache_blocks_evenly`` using the
      per-block byte sizes of both the scorer and proposer caches, so that
      both KV caches get the same number of blocks. The CPU block count
      reported by the scorer is returned unchanged.
      """
      scorer_gpu_blocks, scorer_cpu_blocks = (
          self.scorer_worker.determine_num_available_blocks())

      shared_gpu_blocks = split_num_cache_blocks_evenly(
          self.scorer_worker.get_cache_block_size_bytes(),
          self.proposer_worker.get_cache_block_size_bytes(),
          scorer_gpu_blocks,
      )
      return shared_gpu_blocks, scorer_cpu_blocks
  def initialize_cache(self, num_gpu_blocks: int,
                       num_cpu_blocks: int) -> None:
      """Set up scorer then proposer cache engines with the same block counts.
      """
      for worker in (self.scorer_worker, self.proposer_worker):
          worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                  num_cpu_blocks=num_cpu_blocks)
  @torch.inference_mode()
  def execute_model(
      self,
      execute_model_req: Optional[ExecuteModelRequest] = None
  ) -> List[SamplerOutput]:
      """Perform speculative decoding on the input batch.

      Non-driver ranks run one step driven by the driver's broadcast and
      return an empty list. On the driver, ``execute_model_req=None`` tells
      every other rank to leave its execution loop. Otherwise, the step's
      control parameters are broadcast, and either a non-speculative step or
      a speculative decoding step is run.

      Args:
          execute_model_req: The batch to run, or None to stop the
              non-driver execution loops.

      Returns:
          The SamplerOutputs for this step. The list is empty on non-driver
          ranks and for the stop signal.
      """
      if self.rank != self._driver_rank:
          self._run_non_driver_rank()
          return []

      if execute_model_req is None:
          # This signals that there's no more requests to process for now.
          # All workers are running infinite loop with broadcast_tensor_dict,
          # and it stops the loop when the driver broadcasts an empty input.
          # Broadcast from the same source rank the receivers listen on
          # (self._driver_rank) rather than a hard-coded 0.
          broadcast_tensor_dict({}, src=self._driver_rank)
          return []

      self._track_finished_requests(execute_model_req)
      disable_all_speculation = self._should_disable_all_speculation(
          execute_model_req)
      num_lookahead_slots = execute_model_req.num_lookahead_slots

      # Broadcast how many lookahead slots are scheduled for this step, and
      # whether all speculation is disabled, to all non-driver workers.
      # This is required as if the number of draft model runs changes
      # dynamically, the non-driver workers won't know unless we perform a
      # communication to inform them.
      broadcast_dict = dict(
          num_lookahead_slots=num_lookahead_slots,
          disable_all_speculation=disable_all_speculation,
      )
      broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

      assert execute_model_req.seq_group_metadata_list is not None, (
          "speculative decoding requires non-None seq_group_metadata_list")

      self._maybe_disable_speculative_tokens(
          disable_all_speculation, execute_model_req.seq_group_metadata_list)

      # Speculative decoding is disabled in the following cases:
      # 1. Prefill phase: no lookahead slots are scheduled.
      # 2. Auto-disable enabled: the running queue size reached the
      #    configured threshold.
      # 3. No request: there are no requests in the batch.
      # In any of these cases, the proposer and scorer workers
      # are called normally.
      if num_lookahead_slots == 0 or len(
              execute_model_req.seq_group_metadata_list
      ) == 0 or disable_all_speculation:
          return self._run_no_spec(execute_model_req,
                                   skip_proposer=disable_all_speculation)

      return self._run_speculative_decoding_step(execute_model_req,
                                                 num_lookahead_slots)
  @torch.inference_mode()
  def start_worker_execution_loop(self) -> None:
      """Keep running non-driver steps until the driver sends a stop signal.
      """
      has_more_work = True
      while has_more_work:
          has_more_work = self._run_non_driver_rank()
  def _should_disable_all_speculation(
          self, execute_model_req: ExecuteModelRequest) -> bool:
      """Return True once the running queue reaches disable_by_batch_size.

      Large batches stop trading throughput for latency.
      """
      return (execute_model_req.running_queue_size
              >= self.disable_by_batch_size)
  def _maybe_disable_speculative_tokens(
          self, disable_all_speculation: bool,
          seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
      """Zero ``num_speculative_tokens`` on every group when speculation is off.

      Once a request's num_speculative_tokens is 0, spec decode stays
      disabled for that request for good.
      TODO: This spec-decode state lives in the shared global data
      structure; it should be kept inside the spec decode worker instead.
      """
      if disable_all_speculation:
          for group in seq_group_metadata_list:
              group.num_speculative_tokens = 0
  def _serialize_sampler_output_no_logprobs(
          self, execute_model_req: ExecuteModelRequest,
          sampler_output: SamplerOutput) -> SamplerOutput:
      """Build a CPU-side SamplerOutput holding only the sampled token IDs.

      Logprob-related fields of each CompletionSequenceGroupOutput get
      placeholder values (rank -1, logprob 0.0, empty top-k lists).

      Args:
          execute_model_req: The model request that was executed.
          sampler_output: Sampler output with only GPU tensors populated.

      Returns:
          A new SamplerOutput with one CompletionSequenceGroupOutput per
          sequence id, in get_all_seq_ids order, taking column 0 of the
          matching row of ``sampled_token_ids``.
      """
      all_seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
      token_rows = sampler_output.sampled_token_ids.tolist()

      group_outputs: List[CompletionSequenceGroupOutput] = [
          create_sequence_group_output(
              token_id=token_rows[row][0],
              token_id_logprob_rank=-1,
              token_id_logprob=0.0,
              seq_id=seq_id,
              topk_token_ids=[],
              topk_logprobs=[],
          ) for row, seq_id in enumerate(all_seq_ids)
      ]
      return SamplerOutput(outputs=group_outputs)
  @nvtx_range("spec_decode_worker._run_no_spec")
  def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                   skip_proposer: bool) -> List[SamplerOutput]:
      """Run one ordinary (non-speculative) generation step.

      Normally the batch goes through the proposer too, so the two KV caches
      stay consistent. With ``skip_proposer=True`` the proposer is not run,
      so its KV cache is not updated for these requests and they cannot
      re-enable spec decode later.
      """
      if not skip_proposer:
          self.proposer_worker.execute_model(execute_model_req)

      scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
      assert len(scorer_outputs) == 1
      output = scorer_outputs[0]

      # Keep the target model's hidden states for the next step.
      new_hidden_states = output.hidden_states
      if new_hidden_states is not None:
          groups = execute_model_req.seq_group_metadata_list
          if self.previous_hidden_states is not None:
              self.previous_hidden_states.update(groups, new_hidden_states)
          else:
              self.previous_hidden_states = HiddenStates(
                  groups, new_hidden_states)

      if self._disable_logprobs:
          result = self._serialize_sampler_output_no_logprobs(
              execute_model_req=execute_model_req, sampler_output=output)
      else:
          result = output

      # Drop device tensors from the output object to cut communication
      # overhead when the engine runs in a different process than the workers.
      output.sampled_token_probs = None
      output.sampled_token_ids = None
      output.logprobs = None
      return [result]
  def _run_non_driver_rank(self) -> bool:
      """Run proposer and verifier model in non-driver workers.

      This is used for both speculation cases (num_lookahead_slots > 0) and
      non-speculation cases (e.g. prefill). The control parameters come from
      the dict the driver broadcasts in execute_model.

      Returns True iff there are remaining sequences to process.
      """
      assert self.rank != self._driver_rank

      data = broadcast_tensor_dict(src=self._driver_rank)
      if not data:
          return False
      num_lookahead_slots = data["num_lookahead_slots"]

      # Mirror the driver. When all speculation is disabled, the driver calls
      # _run_no_spec(skip_proposer=True) and never invokes the proposer.
      # This rank calls proposer_worker.execute_model() with no request, so it
      # presumably waits on input from the driver that would never be sent —
      # the reason the driver broadcasts this flag at all.
      # `.get` keeps the old behavior if the key is ever absent.
      if not data.get("disable_all_speculation", False):
          # Even if num_lookahead_slots is zero, we want to run the proposer
          # model as it may have KV.
          #
          # We run the proposer once per lookahead slot. In the future we
          # should delegate how many times it runs to the proposer.
          for _ in range(max(num_lookahead_slots, 1)):
              self.proposer_worker.execute_model()

      self.scorer_worker.execute_model()
      return True
  @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  def _run_speculative_decoding_step(
          self, execute_model_req: ExecuteModelRequest,
          num_lookahead_slots: int) -> List[SamplerOutput]:
      """Run one propose -> score -> verify iteration.

      The proposer drafts up to k tokens per sequence, the scorer evaluates
      them, and _verify_tokens decides which are accepted. Each stage is
      timed. The result is a list of SamplerOutput, each with a single token
      per sequence.
      """
      assert num_lookahead_slots == execute_model_req.num_lookahead_slots

      # Hand the target model's last hidden states to the proposer and drop
      # our own reference to them.
      execute_model_req.previous_hidden_states, self.previous_hidden_states = (
          self.previous_hidden_states, None)

      with Timer() as propose_timer:
          # Generate proposals using draft worker.
          proposals = self.proposer_worker.get_spec_proposals(
              execute_model_req, self._seq_with_bonus_token_in_last_step)

      if not self._allow_zero_draft_token_step and proposals.no_proposals:
          #TODO: Fix it #5814
          raise RuntimeError(
              "Cannot handle cases where distributed draft workers generate "
              "no tokens")

      with Timer() as score_timer:
          scores = self.scorer.score_proposals(execute_model_req, proposals)

      with Timer() as verify_timer:
          k = execute_model_req.num_lookahead_slots
          groups = execute_model_req.seq_group_metadata_list
          accepted_token_ids, target_logprobs = self._verify_tokens(
              groups, scores, proposals, k)

      # Proposal time is reported per lookahead slot.
      stage_times = (
          propose_timer.elapsed_time_ms / num_lookahead_slots,
          score_timer.elapsed_time_ms,
          verify_timer.elapsed_time_ms,
      )
      return self._create_output_sampler_list(groups,
                                              accepted_token_ids,
                                              target_logprobs=target_logprobs,
                                              k=k,
                                              stage_times=stage_times)
  @nvtx_range("spec_decode_worker._verify_tokens")
  def _verify_tokens(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
      proposal_scores: SpeculativeScores,
      proposals: SpeculativeProposals,
      max_proposal_len: int,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
      """Decide which proposed tokens are accepted by the scorer.

      Sequences with a non-zero proposal length ("spec" rows) are passed to
      ``self.spec_decode_sampler``. Sequences with proposal length zero keep
      the scorer's first sampled token, followed by -1 padding up to
      ``max_proposal_len + 1`` columns. The rows of the returned token-id
      tensor are put back into the order of ``seq_group_metadata_list``.

      Side effect: if the scorer returned hidden states, the hidden state at
      each sequence's last accepted position is stored in
      ``self.previous_hidden_states`` for the next step.

      Returns:
          ``(accepted_token_ids, logprobs)``, where ``logprobs`` is
          ``proposal_scores.logprobs`` unchanged.
      """
      proposal_lens_list = proposals.proposal_lens.tolist()

      # Aphrodite currently only supports proposal lens equal to zero or the
      # batch proposal len. This adds some complexity (splitting the batch
      # into spec and non spec sequences) and should be removed in the
      # future. It can be done by supporting per-sequence proposal lens.
      _, spec_indices = split_batch_by_proposal_len(
          seq_group_metadata_list,
          proposal_lens_list,
          select_proposal_len_zero=False)
      _, non_spec_indices = split_batch_by_proposal_len(
          seq_group_metadata_list,
          proposal_lens_list,
          select_proposal_len_zero=True)
      # Row i of the concatenated [spec; non-spec] result belongs at position
      # original_indices[i] of the input batch.
      original_indices = spec_indices + non_spec_indices

      # Get probabilities of target model, excluding bonus token.
      proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
      # Get non-speculative sampled tokens from target model.
      non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
      # Get bonus tokens from target model.
      bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
      # Get probabilities according to proposal method.
      proposal_probs = proposals.proposal_probs[spec_indices]
      # Get proposed tokens.
      proposal_token_ids = proposals.proposal_token_ids[spec_indices]

      # Sampler arguments.
      sampler_extra_kwargs: Dict[str, Any] = {}
      if self.generators and isinstance(self.spec_decode_sampler,
                                        SpecDecodeStochasticBaseSampler):
          # The sampler only receives the spec sub-batch (rows selected by
          # spec_indices above), so key each generator by its row in that
          # sub-batch, not by its position in the full batch. Keying by the
          # full-batch position misaligns generators whenever non-spec
          # sequences are present. (Assumes the sampler treats keys as row
          # indices into target_probs.)
          seeded_seqs = {}
          for row, batch_index in enumerate(spec_indices):
              sgm = seq_group_metadata_list[batch_index]
              if sgm.sampling_params.seed is not None:
                  seeded_seqs[row] = self.generators[sgm.request_id]
          sampler_extra_kwargs["seeded_seqs"] = seeded_seqs

      accepted_token_ids = self.spec_decode_sampler(
          target_probs=proposal_verifier_probs,
          bonus_token_ids=bonus_token_ids,
          draft_probs=proposal_probs,
          draft_token_ids=proposal_token_ids,
          **sampler_extra_kwargs,
      )

      # Append output tokens from non-speculative sequences to the accepted
      # token ids tensor: keep the first token, pad the rest with -1.
      non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                     1).clone()
      non_spec_token_ids[:, 1:] = -1
      accepted_token_ids = torch.cat(
          [accepted_token_ids, non_spec_token_ids])
      logprobs = proposal_scores.logprobs
      # Rearrange so that results are in the order of the original seq group
      # metadata.
      accepted_token_ids[original_indices] = accepted_token_ids.clone()

      hidden_states = proposal_scores.hidden_states
      if hidden_states is not None:
          # Contract hidden states to one vector per sequence, taken at the
          # last accepted position.
          hs_size = hidden_states.shape[1]
          hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
                                                hs_size)
          # -1 padding becomes 0, so count_nonzero counts accepted tokens;
          # subtracting 1 gives the index of the last accepted one.
          accepted_index = accepted_token_ids + 1
          accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
          index = accepted_index[:, None, None].expand(-1, 1, hs_size)
          hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
          # Store hidden states from target model for subsequent decode step
          self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                     hidden_states)

      return accepted_token_ids, logprobs
  572. def _create_output_sampler_list(
  573. self,
  574. seq_group_metadata_list: List[SequenceGroupMetadata],
  575. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  576. target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
  577. k: int,
  578. stage_times: Tuple[float, float, float],
  579. ) -> List[SamplerOutput]:
  580. """Given the accepted token ids, create a list of SamplerOutput.
  581. The output is padded with -1 tokens such that each sequence has
  582. the same number of outputs.
  583. """
  584. batch_size, num_steps = accepted_token_ids.shape
  585. accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
  586. if self._disable_logprobs:
  587. # We are skipping the logprobs. Hence don't serialize the
  588. # logprobs related tensors from the GPU. Instead create
  589. # empty/dummy lists.
  590. (accepted_token_id_ranks_by_step,
  591. accepted_token_id_logprobs_by_step,
  592. topk_logprobs_by_step, topk_indices_by_step) =\
  593. self._create_dummy_logprob_lists(
  594. batch_size, num_steps,
  595. self.scorer_worker.model_config.max_logprobs)
  596. else:
  597. # Organize input tensors by step instead of by sequence.
  598. target_logprobs_by_step = target_logprobs.transpose(0, 1)
  599. # Serialize all tensors into Python lists.
  600. (accepted_token_id_ranks_by_step,
  601. accepted_token_id_logprobs_by_step,
  602. topk_logprobs_by_step, topk_indices_by_step) =\
  603. self._create_logprob_lists_from_tensors(
  604. target_logprobs_by_step, accepted_token_ids_by_step,
  605. self.scorer_worker.model_config.max_logprobs)
  606. # Get the sequence ids and num_logprobs (sampling parameter) in the
  607. # batch.
  608. seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
  609. seq_group_metadata_list)
  610. num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
  611. # Serialize tensor to CPU Python list.
  612. accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
  613. # Construct the output on a per-step, per-sequence basis.
  614. sampler_output_list: List[SamplerOutput] = []
  615. for step_index in range(num_steps):
  616. if all(token_id == -1
  617. for token_id in accepted_token_ids_by_step[step_index]):
  618. break
  619. step_output_token_ids: List[CompletionSequenceGroupOutput] = []
  620. for sequence_index in range(batch_size):
  621. # Each sequence may have a different num_logprobs; retrieve it.
  622. num_logprobs = num_logprobs_per_seq[sequence_index]
  623. step_output_token_ids.append(
  624. create_sequence_group_output(
  625. token_id=accepted_token_ids_by_step[step_index]
  626. [sequence_index],
  627. token_id_logprob_rank=accepted_token_id_ranks_by_step[
  628. step_index][sequence_index],
  629. token_id_logprob=accepted_token_id_logprobs_by_step[
  630. step_index][sequence_index],
  631. seq_id=seq_ids[sequence_index],
  632. topk_token_ids=topk_indices_by_step[step_index]
  633. [sequence_index][:num_logprobs],
  634. topk_logprobs=topk_logprobs_by_step[step_index]
  635. [sequence_index][:num_logprobs],
  636. ))
  637. sampler_output_list.append(
  638. SamplerOutput(outputs=step_output_token_ids))
  639. # Populate the data structures needed to keep track of sequences with
  640. # bonus tokens.
  641. self._track_sequences_with_bonus_tokens(seq_ids,
  642. request_ids_seq_ids_mapping,
  643. accepted_token_ids_by_step)
  644. maybe_rejsample_metrics = (
  645. self._metrics.maybe_collect_rejsample_metrics(k))
  646. if maybe_rejsample_metrics is not None:
  647. sampler_output_list[
  648. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  649. # Log time spent in each stage periodically.
  650. # This is periodic because the rejection sampler emits metrics
  651. # periodically.
  652. self._maybe_log_stage_times(*stage_times)
  653. return sampler_output_list
  654. def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
  655. scoring_time_ms: float,
  656. verification_time_ms: float) -> None:
  657. """Log the speculative stage times. If stat logging is disabled, do
  658. nothing.
  659. """
  660. if self._disable_log_stats:
  661. return
  662. logger.info(f"SpecDecodeWorker stage times: "
  663. f"average_time_per_proposal_tok_ms="
  664. f"{average_time_per_proposal_tok_ms:.02f} "
  665. f"scoring_time_ms={scoring_time_ms:.02f} "
  666. f"verification_time_ms={verification_time_ms:.02f}")
  667. def _create_dummy_logprob_lists(
  668. self,
  669. batch_size: int,
  670. num_steps: int,
  671. num_top_k: int,
  672. ) -> Tuple[List[List[int]], List[List[float]],
  673. List[List[List[Optional[float]]]],
  674. List[List[List[Optional[int]]]]]:
  675. """
  676. Creates and returns four dummy lists representing token probabilities
  677. and their ranks.
  678. This method initializes and returns:
  679. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  680. - The log probabilities of the accepted tokens,
  681. shaped (num_steps, batch_size)
  682. - The log probabilities of the top k tokens,
  683. shaped (num_steps, batch_size, num_top_k)
  684. - The token IDs of the top k tokens,
  685. shaped (num_steps, batch_size, num_top_k)
  686. Args:
  687. batch_size (int): The size of the batch.
  688. num_steps (int): The number of steps in the sequence.
  689. num_top_k (int): The number of top-k token log probabilities to
  690. return.
  691. Returns:
  692. A tuple containing four dummy lists as described above.
  693. """
  694. accepted_token_id_ranks_by_step = [[-1] * batch_size
  695. for _ in range(num_steps)]
  696. accepted_token_id_logprobs_by_step = [[0.0] * batch_size
  697. for _ in range(num_steps)]
  698. topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
  699. [None] * num_top_k for _ in range(batch_size)
  700. ] for _ in range(num_steps)]
  701. topk_indices_by_step: List[List[List[Optional[int]]]] = [[
  702. [None] * num_top_k for _ in range(batch_size)
  703. ] for _ in range(num_steps)]
  704. return (accepted_token_id_ranks_by_step,
  705. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  706. topk_indices_by_step)
  707. def _create_logprob_lists_from_tensors(
  708. self,
  709. target_logprobs_by_step: torch.Tensor,
  710. accepted_token_ids_by_step: torch.Tensor,
  711. num_top_k: int,
  712. ) -> Tuple[List[List[int]], List[List[float]],
  713. List[List[List[Optional[float]]]],
  714. List[List[List[Optional[int]]]]]:
  715. """
  716. Creates and returns four lists representing token probabilities and
  717. their ranks.
  718. This method initializes and returns four lists containing:
  719. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  720. - The log probabilities of the accepted tokens,
  721. shaped (num_steps, batch_size)
  722. - The log probabilities of the top k tokens,
  723. shaped (num_steps, batch_size, num_top_k)
  724. - The token IDs of the top k tokens,
  725. shaped (num_steps, batch_size, num_top_k)
  726. Args:
  727. target_logprobs_by_step (torch.Tensor): Tensor representing the
  728. log probabilities of the target model,
  729. shaped (num_steps, batch_size, vocab_size)
  730. accepted_token_ids_by_step (torch.Tensor): Tensor representing
  731. the accepted token_ids, shaped (num_steps, batch_size)
  732. num_top_k (int): The number of top-k token log probabilities to
  733. return.
  734. Returns:
  735. A tuple containing the lists as described above.
  736. """
  737. # Serialize all tensors to CPU Python lists.
  738. # Get the logprobs/rank of the accepted tokens.
  739. (accepted_token_id_ranks_by_step_tensor,
  740. accepted_token_id_logprobs_by_step_tensor
  741. ) = get_sampled_token_logprobs(
  742. logprob_tensor=target_logprobs_by_step,
  743. sampled_token_ids=accepted_token_ids_by_step,
  744. )
  745. # Get the top-k logprobs (which may or may not include the
  746. # logprob of the accepted token).
  747. (topk_logprobs_by_step_tensor,
  748. topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
  749. k=num_top_k,
  750. dim=-1,
  751. )
  752. accepted_token_id_ranks_by_step = (
  753. accepted_token_id_ranks_by_step_tensor.tolist())
  754. accepted_token_id_logprobs_by_step = (
  755. accepted_token_id_logprobs_by_step_tensor.tolist())
  756. topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
  757. topk_indices_by_step = topk_indices_by_step_tensor.tolist()
  758. return (accepted_token_id_ranks_by_step,
  759. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  760. topk_indices_by_step)
  761. def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
  762. """
  763. Removes the finished requests and their associated sequence ids from
  764. internal book keeping data structures.
  765. """
  766. for finished_request in execute_model_req.finished_requests_ids:
  767. for seq_id in self._request_id_seq_id_mapping[finished_request]:
  768. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  769. del self._request_id_seq_id_mapping[finished_request]
  770. def _track_sequences_with_bonus_tokens(
  771. self, seq_ids: List[int],
  772. request_ids_seq_ids_mapping: Dict[str, Set[int]],
  773. accepted_token_ids_by_step: List[List[int]]):
  774. """
  775. Updates the internal data structures which keep track of sequences
  776. which have been assigned bonus tokens in their last forward pass.
  777. """
  778. for seq_index, seq_id in enumerate(seq_ids):
  779. last_token_id = accepted_token_ids_by_step[-1][seq_index]
  780. if last_token_id == -1:
  781. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  782. else:
  783. self._seq_with_bonus_token_in_last_step.add(seq_id)
  784. for request_id, sequences in request_ids_seq_ids_mapping.items():
  785. self._request_id_seq_id_mapping[request_id].update(sequences)
  786. @cached_property
  787. def _vocab_size(self) -> int:
  788. """Get the vocab size of the model and make sure it's consistent between
  789. draft and target workers.
  790. """
  791. vocab_sizes = [
  792. worker.vocab_size
  793. for worker in [self.proposer_worker, self.scorer_worker]
  794. ]
  795. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  796. return vocab_sizes[0]
  797. @property
  798. def rank(self):
  799. return self.scorer_worker.rank
  800. @property
  801. def device(self):
  802. return self.scorer_worker.device
  803. @property
  804. def _driver_rank(self) -> int:
  805. return 0
  806. def get_cache_block_size_bytes(self):
  807. """Return the size of a cache block in bytes.
  808. This function is only used to compose workers within a SpecDecodeWorker.
  809. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  810. undefined for now, although it could be implemented in the future.
  811. See https://arxiv.org/abs/2308.04623.
  812. """
  813. raise NotImplementedError
  814. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  815. proposer_cache_block_size_bytes: int,
  816. total_num_gpu_blocks: int) -> int:
  817. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  818. allocate to the target model, this function calculates how many blocks
  819. should be given to the draft and target model.
  820. Note that usually the block size, in bytes, of each model is different,
  821. as it's a function of number of KV/layer, number of heads, and hidden
  822. dimension size.
  823. Since the target and draft models allocate the same number of blocks, we
  824. simply calculate the number of blocks where if allocated by both models,
  825. the total memory usage from KV cache is no larger than the number of
  826. blocks allocatable by the target model alone.
  827. """
  828. new_num_gpu_blocks = int(
  829. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  830. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  831. return new_num_gpu_blocks