spec_decode_worker.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913
  1. from collections import defaultdict
  2. from functools import cached_property
  3. from typing import Any, Dict, List, Optional, Set, Tuple
  4. import torch
  5. from loguru import logger
  6. from aphrodite.common.config import ParallelConfig, SpeculativeConfig
  7. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  8. ExecuteModelRequest, HiddenStates,
  9. SamplerOutput, SequenceGroupMetadata,
  10. get_all_seq_ids,
  11. get_all_seq_ids_and_request_ids)
  12. from aphrodite.distributed.communication_op import broadcast_tensor_dict
  13. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  14. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  15. SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
  16. from aphrodite.modeling.layers.typical_acceptance_sampler import \
  17. TypicalAcceptanceSampler
  18. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  19. from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
  20. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  21. SpeculativeScorer,
  22. SpeculativeScores)
  23. from aphrodite.spec_decode.medusa_worker import MedusaWorker
  24. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  25. from aphrodite.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
  26. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  27. from aphrodite.spec_decode.ngram_worker import NGramWorker
  28. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  29. from aphrodite.spec_decode.smaller_tp_proposer_worker import \
  30. SmallerTpProposerWorker
  31. from aphrodite.spec_decode.target_model_runner import TargetModelRunner
  32. from aphrodite.spec_decode.util import (create_sequence_group_output,
  33. get_all_num_logprobs,
  34. get_sampled_token_logprobs, nvtx_range,
  35. split_batch_by_proposal_len)
  36. from aphrodite.task_handler.worker import Worker
  37. from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
  38. WorkerBase)
  def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
      """Build a SpecDecodeWorker from the speculative config.

      Entrypoint for Executors which use WorkerWrapper. The positional and
      keyword arguments are those of a regular ``Worker``; ``kwargs`` must
      also carry a non-None ``speculative_config``.

      The target (scorer) worker is built from the given arguments with
      ``TargetModelRunner`` as its runner class. The draft worker arguments
      are a copy of the incoming ``kwargs`` in which the model config,
      parallel config and n-gram lookup window come from the speculative
      config.

      Returns:
          The worker produced by ``SpecDecodeWorker.create_worker``.
      """
      assert "speculative_config" in kwargs
      spec_config: SpeculativeConfig = kwargs.get("speculative_config")
      assert spec_config is not None

      # Snapshot the arguments before the target-only runner class is
      # injected, so the draft worker does not inherit it.
      draft_kwargs = dict(kwargs)

      kwargs["model_runner_cls"] = TargetModelRunner
      target_worker = Worker(*args, **kwargs)
      # The target runner follows the speculative config's logprob setting.
      target_worker.model_runner.disable_logprobs = (
          spec_config.disable_logprobs)

      # Draft-model specific overrides.
      # TODO allow draft-model specific load config.
      draft_kwargs.update(
          model_config=spec_config.draft_model_config,
          parallel_config=spec_config.draft_parallel_config,
          ngram_prompt_lookup_max=spec_config.ngram_prompt_lookup_max,
          ngram_prompt_lookup_min=spec_config.ngram_prompt_lookup_min,
      )

      return SpecDecodeWorker.create_worker(
          scorer_worker=target_worker,
          draft_worker_kwargs=draft_kwargs,
          disable_by_batch_size=spec_config.speculative_disable_by_batch_size,
          draft_token_acceptance_method=(
              spec_config.draft_token_acceptance_method),
          typical_acceptance_sampler_posterior_threshold=(
              spec_config.typical_acceptance_sampler_posterior_threshold),
          typical_acceptance_sampler_posterior_alpha=(
              spec_config.typical_acceptance_sampler_posterior_alpha),
          disable_logprobs=spec_config.disable_logprobs,
      )
  75. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  76. """Worker which implements speculative decoding.
  77. Speculative decoding reduces decoding per-token latency by using a proposal
  78. method, such as a small draft model, to speculate ahead of a larger LLM. The
  79. probabilities of the speculative tokens are then determined by the larger
  80. LLM, after which some verification routine determines which (if any) of the
  81. speculative tokens are accepted by the larger LLM.
  82. The current implementation has the following limitations:
  83. * Only draft-model proposal is implemented (contributions for more forms are
  84. welcome!).
  85. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  86. future work.
  87. * All sequences in a batch must have the same proposal length, or zero. This
  88. can be improved by having per-sequence speculation in the future.
  89. * The scoring forward pass is done without an MQA kernel, which is
  90. suboptimal especially as the batch size, proposal length, and sequence
  91. lengths grow. Contributions to add a MQA scoring are welcome once
  92. correctness tests pass.
  93. """
  @classmethod
  def create_worker(
      cls,
      scorer_worker: Worker,
      draft_worker_kwargs: Dict[str, Any],
      disable_by_batch_size: Optional[int],
      draft_token_acceptance_method: str,
      typical_acceptance_sampler_posterior_threshold: float,
      typical_acceptance_sampler_posterior_alpha: float,
      disable_logprobs: bool,
  ) -> "SpecDecodeWorker":
      """Pick a proposer worker and an acceptance sampler, then build the
      SpecDecodeWorker.

      Proposer selection:
        * ``ngram_prompt_lookup_max > 0``: ``NGramWorker`` with the n-gram
          window ``[ngram_prompt_lookup_min, ngram_prompt_lookup_max]``.
        * draft ``hf_config.model_type == "mlp_speculator"``:
          ``MLPSpeculatorWorker``.
        * draft ``hf_config.model_type == "medusa"``: ``MedusaWorker``.
        * otherwise ``MultiStepWorker``. It uses ``TP1DraftModelRunner``
          when the draft TP size is 1. Otherwise, steps with zero draft
          tokens are disallowed.
      Every non-ngram proposer is passed through
      ``SmallerTpProposerWorker.maybe_wrap_worker`` together with the draft
      and target TP sizes.

      Args:
          scorer_worker: The target-model worker used for scoring.
          draft_worker_kwargs: Constructor arguments for the proposer.
              NOTE: ``ngram_prompt_lookup_max`` and
              ``ngram_prompt_lookup_min`` are popped from this dict, and
              ``model_runner_cls`` may be added to it.
          disable_by_batch_size: Forwarded to ``SpecDecodeWorker``.
          draft_token_acceptance_method: ``"rejection_sampler"`` or
              ``"typical_acceptance_sampler"``.
          typical_acceptance_sampler_posterior_threshold: Posterior threshold
              for ``TypicalAcceptanceSampler``.
          typical_acceptance_sampler_posterior_alpha: Posterior alpha for
              ``TypicalAcceptanceSampler``.
          disable_logprobs: Forwarded to ``SpecDecodeWorker``.

      Returns:
          A new ``SpecDecodeWorker``.

      Raises:
          ValueError: If ``draft_token_acceptance_method`` is not one of the
              supported values.
      """
      allow_zero_draft_token_step = True
      ngram_prompt_lookup_max = (
          draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
      ngram_prompt_lookup_min = (
          draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
      if ngram_prompt_lookup_max > 0:
          proposer_worker = NGramWorker(**draft_worker_kwargs)
          proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                ngram_prompt_lookup_max)
      else:
          draft_parallel_config: ParallelConfig = draft_worker_kwargs[
              'parallel_config']
          draft_tp = draft_parallel_config.tensor_parallel_size
          target_tp = scorer_worker.parallel_config.tensor_parallel_size

          draft_model_type = (
              draft_worker_kwargs["model_config"].hf_config.model_type)
          if draft_model_type == "mlp_speculator":
              proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
          elif draft_model_type == "medusa":
              proposer_worker = MedusaWorker(**draft_worker_kwargs)
          else:
              if draft_tp == 1:
                  draft_worker_kwargs[
                      "model_runner_cls"] = TP1DraftModelRunner
              else:
                  # Distributed draft workers cannot handle steps in which
                  # no draft tokens are produced.
                  allow_zero_draft_token_step = False
              proposer_worker = MultiStepWorker(**draft_worker_kwargs)

          proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
              proposer_worker, draft_tp, target_tp)

      logger.info("Configuring SpecDecodeWorker with "
                  f"proposer={type(proposer_worker)}")

      spec_decode_sampler: Optional[SpecDecodeBaseSampler] = None
      if draft_token_acceptance_method == "rejection_sampler":
          spec_decode_sampler = RejectionSampler(
              disable_bonus_tokens=False, )
      elif draft_token_acceptance_method == "typical_acceptance_sampler":
          spec_decode_sampler = TypicalAcceptanceSampler(
              disable_bonus_tokens=False,
              posterior_threshold=\
                  typical_acceptance_sampler_posterior_threshold,
              posterior_alpha=typical_acceptance_sampler_posterior_alpha,
          )
      else:
          # Fail fast with a clear message. Passing None on would only break
          # later, when __init__ reads the sampler's dtypes.
          raise ValueError(
              "Unsupported draft_token_acceptance_method "
              f"{draft_token_acceptance_method!r}; expected "
              "'rejection_sampler' or 'typical_acceptance_sampler'.")

      logger.info("Configuring SpecDecodeWorker with "
                  f"sampler={type(spec_decode_sampler)}")

      return SpecDecodeWorker(
          proposer_worker,
          scorer_worker,
          disable_logprobs=disable_logprobs,
          disable_by_batch_size=disable_by_batch_size,
          spec_decode_sampler=spec_decode_sampler,
          allow_zero_draft_token_step=allow_zero_draft_token_step)
  156. def __init__(
  157. self,
  158. proposer_worker: ProposerWorkerBase,
  159. scorer_worker: WorkerBase,
  160. spec_decode_sampler: SpecDecodeBaseSampler,
  161. disable_logprobs: bool,
  162. metrics_collector: Optional[AsyncMetricsCollector] = None,
  163. disable_by_batch_size: Optional[int] = None,
  164. allow_zero_draft_token_step: Optional[bool] = True,
  165. ):
  166. """
  167. Create a SpecDecodeWorker.
  168. Args:
  169. proposer_worker: A worker that can produce speculative tokens for
  170. sequences.
  171. scorer_worker: A worker that produces probabilities of speculative
  172. tokens according to some base model. Typically a vanilla vLLM
  173. Worker.
  174. spec_decode_sampler: A Torch module used to perform acceptance
  175. sampling of the draft tokens in the verification step of
  176. speculative decoding. Currently we support two different
  177. types of sampler namely RejectionSampler and
  178. TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
  179. instance of RejectionSampler or TypicalAcceptanceSampler.
  180. disable_logprobs: If set to True, token log probabilities will
  181. not be output in both the draft worker and the target worker.
  182. If set to False, log probabilities will be output by both.
  183. disable_by_batch_size: If the batch size is larger than this,
  184. disable speculative decoding for new incoming requests.
  185. metrics_collector: Helper class for collecting metrics; can be set
  186. for testing purposes.
  187. allow_zero_draft_token_step: whether to allow a step where the draft
  188. model generates no draft token; should disallow when the tp of
  189. draft model is larger than 1
  190. """
  191. self.proposer_worker = proposer_worker
  192. self.scorer_worker = scorer_worker
  193. self.disable_by_batch_size = disable_by_batch_size or float("inf")
  194. self.spec_decode_sampler = spec_decode_sampler
  195. self._allow_zero_draft_token_step = allow_zero_draft_token_step
  196. self._metrics = AsyncMetricsCollector(
  197. self.spec_decode_sampler
  198. ) if metrics_collector is None else metrics_collector
  199. # Tracks the sequence IDs that received a bonus token ID in
  200. # their last forward pass. Needed only if KV cache is being
  201. # used for token generation such as in the case of MultiStepWorker.
  202. self._seq_with_bonus_token_in_last_step: Set[int] = set()
  203. # Tracks the currently active request ids and the sequence IDs
  204. # corresponding to them
  205. self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
  206. # Tracks if the proposer worker uses the KV cache or not.
  207. self.probs_dtype = self.spec_decode_sampler.probs_dtype
  208. self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
  209. # Lazy initialization.
  210. self.scorer: SpeculativeScorer
  211. # Hidden states from target model to pass to proposer
  212. # in the subsequent step.
  213. self.previous_hidden_states: Optional[HiddenStates] = None
  214. self._disable_logprobs = disable_logprobs
  def init_device(self) -> None:
      """Initialize devices and models for both workers, then set up the
      verification machinery (metrics, sampler GPU tensors, scorer).
      """
      # The scorer worker model is initialized first in case the proposer
      # model has a smaller TP degree than the target worker.
      for worker in (self.scorer_worker, self.proposer_worker):
          worker.init_device()

      # NOTE: load_model is not part of the WorkerBase interface.
      for worker in (self.scorer_worker, self.proposer_worker):
          worker.load_model()

      self._metrics.init_gpu_tensors(self.rank)
      self.spec_decode_sampler.init_gpu_tensors(self.rank)

      self.scorer = BatchExpansionTop1Scorer(scorer_worker=self.scorer_worker,
                                             device=self.device,
                                             vocab_size=self._vocab_size)

      self._configure_model_sampler_for_spec_decode()
  def load_model(self, *args, **kwargs):
      """Intentionally a no-op.

      Both the scorer and the proposer models are loaded by ``init_device``,
      so this method does nothing and ignores its arguments.
      """
      return None
  234. def _configure_model_sampler_for_spec_decode(self):
  235. """Configure model sampler to emit GPU tensors. This allows spec decode
  236. to keep data on device without transferring to CPU and serializing,
  237. which significantly reduces overhead of sampling during verification.
  238. NOTE: This breaks abstraction boundaries pretty badly. The better
  239. design is to have the "move to CPU and serialize" sampling decision be
  240. done outside of the model/sampler; this way the "last-mile" worker
  241. object which interfaces with the scheduler can serialize and incur the
  242. performance hit as necessary. This allows us to run the worker several
  243. iterations in a row without incurring the "move to CPU and serialize"
  244. performance penalty.
  245. Since this requires a large change to Aphrodite, we defer it to later
  246. and temporarily accept this broken abstraction boundary.
  247. NOTE: This will require a special check if the proposer worker
  248. does not have a sampler (e.g. ngram speculation).
  249. """
  250. (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
  251. ) = True
  252. self.proposer_worker.set_include_gpu_probs_tensor()
  def determine_num_available_blocks(self) -> Tuple[int, int]:
      """Determine the number of cache blocks to use.

      The scorer model (typically the larger of the two) is profiled. Its
      GPU block budget is then re-split by ``split_num_cache_blocks_evenly``
      using the per-block byte sizes of both workers. The intent is that the
      scorer's cache memory is shared so both KV caches get the same number
      of blocks. The CPU block count is returned unchanged.

      Returns:
          ``(num_gpu_blocks, num_cpu_blocks)``.
      """
      profiled_gpu_blocks, num_cpu_blocks = (
          self.scorer_worker.determine_num_available_blocks())

      scorer_block_bytes = self.scorer_worker.get_cache_block_size_bytes()
      proposer_block_bytes = self.proposer_worker.get_cache_block_size_bytes()

      shared_gpu_blocks = split_num_cache_blocks_evenly(
          scorer_block_bytes, proposer_block_bytes, profiled_gpu_blocks)
      return shared_gpu_blocks, num_cpu_blocks
  def initialize_cache(self, num_gpu_blocks: int,
                       num_cpu_blocks: int) -> None:
      """Initialize the cache engines of the scorer and proposer workers
      with the same block counts (scorer first).
      """
      for worker in (self.scorer_worker, self.proposer_worker):
          worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                  num_cpu_blocks=num_cpu_blocks)
  @torch.inference_mode()
  def execute_model(
      self,
      execute_model_req: Optional[ExecuteModelRequest] = None
  ) -> List[SamplerOutput]:
      """Perform speculative decoding on the input batch.

      On a non-driver rank this runs one mirrored step and returns ``[]``.
      On the driver rank:
        * if ``execute_model_req`` is None, an empty dict is broadcast to
          stop the non-driver execution loops and ``[]`` is returned;
        * otherwise the step metadata (``num_lookahead_slots`` and
          ``disable_all_speculation``) is broadcast. Then either a
          non-speculative step (``_run_no_spec``) or a speculative step
          (``_run_speculative_decoding_step``) is run.

      Returns:
          A list of SamplerOutput (empty on non-driver ranks).
      """
      if self.rank != self._driver_rank:
          self._run_non_driver_rank()
          return []

      if execute_model_req is None:
          # This signals that there's no more requests to process for now.
          # All workers are running infinite loop with broadcast_tensor_dict,
          # and it stops the loop when the driver broadcasts an empty input.
          # Send an empty input to notify all other workers to stop their
          # execution loop. Use the same source rank as every other broadcast
          # so it always pairs with the receive in _run_non_driver_rank.
          broadcast_tensor_dict({}, src=self._driver_rank)
          return []

      self._track_finished_requests(execute_model_req)
      disable_all_speculation = self._should_disable_all_speculation(
          execute_model_req)
      num_lookahead_slots = execute_model_req.num_lookahead_slots

      # Broadcast how many lookahead slots are scheduled for this step, and
      # whether all speculation is disabled, to all non-driver workers.
      # This is required as if the number of draft model runs changes
      # dynamically, the non-driver workers won't know unless we perform a
      # communication to inform them.
      broadcast_dict = dict(
          num_lookahead_slots=num_lookahead_slots,
          disable_all_speculation=disable_all_speculation,
      )
      broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

      assert execute_model_req.seq_group_metadata_list is not None, (
          "speculative decoding requires non-None seq_group_metadata_list")

      self._maybe_disable_speculative_tokens(
          disable_all_speculation, execute_model_req.seq_group_metadata_list)

      # Speculative decoding is disabled in the following cases:
      # 1. Prefill phase: Speculative decoding is not
      #    used during the prefill phase.
      # 2. Auto-disable enabled: The running queue size exceeds
      #    the specified threshold.
      # 3. No request: There are no requests in the batch.
      # In any of these cases, the proposer and scorer workers
      # are called normally.
      if num_lookahead_slots == 0 or len(
              execute_model_req.seq_group_metadata_list
      ) == 0 or disable_all_speculation:
          return self._run_no_spec(execute_model_req,
                                   skip_proposer=disable_all_speculation)

      return self._run_speculative_decoding_step(execute_model_req,
                                                 num_lookahead_slots)
  @torch.inference_mode()
  def start_worker_execution_loop(self) -> None:
      """Run the speculative-decoding loop on a non-driver worker.

      Repeats ``_run_non_driver_rank`` until it reports that no sequences
      remain to be processed.
      """
      keep_going = self._run_non_driver_rank()
      while keep_going:
          keep_going = self._run_non_driver_rank()
  335. def _should_disable_all_speculation(
  336. self, execute_model_req: ExecuteModelRequest) -> bool:
  337. # When the batch size is too large, disable speculative decoding
  338. # to stop trading off throughput for latency.
  339. disable_all_speculation = (execute_model_req.running_queue_size >=
  340. self.disable_by_batch_size)
  341. return disable_all_speculation
  342. def _maybe_disable_speculative_tokens(
  343. self, disable_all_speculation: bool,
  344. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  345. if not disable_all_speculation:
  346. return
  347. for seq_group_metadata in seq_group_metadata_list:
  348. # Once num_speculative_tokens is set to 0, the spec decode
  349. # of this request will be disabled forever.
  350. # TODO: We currently store spec decoding specific
  351. # state in the global data structure, but we should maintain
  352. # this state within spec decode worker.
  353. seq_group_metadata.num_speculative_tokens = 0
  def _serialize_sampler_output_no_logprobs(
          self, execute_model_req: ExecuteModelRequest,
          sampler_output: SamplerOutput) -> SamplerOutput:
      """Build a SamplerOutput that carries only the sampled token IDs.

      The sampled token IDs are copied to CPU. For each sequence, the first
      column of ``sampler_output.sampled_token_ids`` becomes a
      ``CompletionSequenceGroupOutput``. All logprob-related fields get
      placeholders: rank -1, logprob 0.0 and empty top-k lists.

      Args:
          execute_model_req: The model request that was executed.
          sampler_output: Sampler output with only GPU tensors populated.

      Returns:
          A new SamplerOutput with one entry per sequence ID, in the order
          returned by ``get_all_seq_ids``.
      """
      seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
      token_rows = sampler_output.sampled_token_ids.tolist()

      outputs: List[CompletionSequenceGroupOutput] = [
          create_sequence_group_output(
              token_id=token_rows[row][0],
              token_id_logprob_rank=-1,
              token_id_logprob=0.0,
              seq_id=seq_id,
              topk_token_ids=[],
              topk_logprobs=[],
          ) for row, seq_id in enumerate(seq_ids)
      ]
      return SamplerOutput(outputs=outputs)
  387. @nvtx_range("spec_decode_worker._run_no_spec")
  388. def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
  389. skip_proposer: bool) -> List[SamplerOutput]:
  390. """Run a single generation step without any speculation. The input is
  391. sent to the proposer and scorer model so that the KV cache is consistent
  392. between the two. When skip_proposer is True, the proposer model is
  393. not called, meaning that the kv-cache in proposer for requests is not
  394. updated, so they cannot enable spec decode in the rest decoding.
  395. """
  396. if not skip_proposer:
  397. self.proposer_worker.execute_model(execute_model_req)
  398. sampler_output = self.scorer_worker.execute_model(execute_model_req)
  399. assert len(sampler_output) == 1
  400. sampler_output = sampler_output[0]
  401. # Store hidden states from target model execution.
  402. hidden_states = sampler_output.hidden_states
  403. if hidden_states is not None:
  404. if self.previous_hidden_states is None:
  405. self.previous_hidden_states = HiddenStates(
  406. execute_model_req.seq_group_metadata_list, hidden_states)
  407. else:
  408. self.previous_hidden_states.update(
  409. execute_model_req.seq_group_metadata_list, hidden_states)
  410. sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
  411. execute_model_req=execute_model_req, sampler_output=sampler_output)
  412. if self._disable_logprobs else
  413. sampler_output)
  414. # Clear device tensors from sampler output. This reduces communication
  415. # overhead when the engine runs in a different process than the workers.
  416. sampler_output.sampled_token_probs = None
  417. sampler_output.sampled_token_ids = None
  418. sampler_output.logprobs = None
  419. return [sampler_output_to_return]
  def _run_non_driver_rank(self) -> bool:
      """Mirror one driver step on a non-driver rank.

      Receives the metadata the driver broadcasts in ``execute_model`` and
      runs the proposer and scorer to match what the driver runs. This is
      used both for speculative steps (num_lookahead_slots > 0) and
      non-speculative steps (e.g. prefill).

      Returns:
          True iff there are remaining sequences to process. False when the
          driver broadcast an empty dict to end the loop.
      """
      assert self.rank != self._driver_rank

      data = broadcast_tensor_dict(src=self._driver_rank)
      if not data:
          return False
      num_lookahead_slots = data["num_lookahead_slots"]

      # When all speculation is disabled, the driver calls
      # _run_no_spec(..., skip_proposer=True) and never runs the proposer.
      # Mirror that here. Otherwise this rank runs proposer steps the
      # driver does not, which would presumably desynchronize the ranks'
      # collective calls. Default to False if an older driver did not send
      # the flag.
      if not data.get("disable_all_speculation", False):
          # Even if num_lookahead_slots is zero, we want to run the proposer
          # model as it may have KV.
          #
          # We run the proposer once per lookahead slot. In the future we
          # should delegate how many times it runs to the proposer.
          for _ in range(max(num_lookahead_slots, 1)):
              self.proposer_worker.execute_model()

      self.scorer_worker.execute_model()
      return True
  440. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  441. def _run_speculative_decoding_step(
  442. self, execute_model_req: ExecuteModelRequest,
  443. num_lookahead_slots: int) -> List[SamplerOutput]:
  444. """Execute a single step of speculative decoding.
  445. This invokes the proposer worker to get k speculative tokens for each
  446. sequence, then scores each speculative token using the scoring worker.
  447. Returns a list of SamplerOutput, each containing a single token per
  448. sequence.
  449. """
  450. assert num_lookahead_slots == execute_model_req.num_lookahead_slots
  451. # Pass last hidden states from target model to proposer
  452. execute_model_req.previous_hidden_states = self.previous_hidden_states
  453. self.previous_hidden_states = None
  454. # Generate proposals using draft worker.
  455. proposals = self.proposer_worker.get_spec_proposals(
  456. execute_model_req, self._seq_with_bonus_token_in_last_step)
  457. if not self._allow_zero_draft_token_step and proposals.no_proposals:
  458. #TODO: Fix it #5814
  459. raise RuntimeError("Cannot handle cases where distributed draft "
  460. "workers generate no tokens")
  461. proposal_scores = self.scorer.score_proposals(
  462. execute_model_req,
  463. proposals,
  464. )
  465. accepted_token_ids, target_logprobs = self._verify_tokens(
  466. execute_model_req.seq_group_metadata_list, proposal_scores,
  467. proposals, execute_model_req.num_lookahead_slots)
  468. return self._create_output_sampler_list(
  469. execute_model_req.seq_group_metadata_list,
  470. accepted_token_ids,
  471. target_logprobs=target_logprobs,
  472. k=execute_model_req.num_lookahead_slots)
  @nvtx_range("spec_decode_worker._verify_tokens")
  def _verify_tokens(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
      proposal_scores: SpeculativeScores,
      proposals: SpeculativeProposals,
      max_proposal_len: int,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
      """Determine which speculative tokens are accepted.

      Acceptance uses the probabilities of each token according to the
      proposer and scorer models. Sequences whose proposal length is zero
      bypass the acceptance sampler and keep only the scorer's first
      sampled token.

      Args:
          seq_group_metadata_list: Sequence groups in batch order.
          proposal_scores: Scorer outputs (``probs``, ``token_ids``,
              ``logprobs`` and optional ``hidden_states``). Rows are indexed
              by batch position.
          proposals: Proposer outputs (``proposal_lens``, ``proposal_probs``,
              ``proposal_token_ids``). Rows are indexed by batch position.
          max_proposal_len: The batch proposal length k.

      Returns:
          ``(accepted_token_ids, logprobs)``. ``accepted_token_ids`` has
          k + 1 columns per sequence, rearranged back into the original batch
          order, with -1 marking positions without an accepted token.
          ``logprobs`` is ``proposal_scores.logprobs`` unchanged.
      """
      proposal_lens_list = proposals.proposal_lens.tolist()

      # Aphrodite currently only supports proposal lens equal to zero or the
      # batch proposal len. This adds some complexity (splitting the batch
      # into spec and non spec sequences) and should be removed in the
      # future. It can be done by supporting per-sequence proposal lens.
      _, spec_indices = split_batch_by_proposal_len(
          seq_group_metadata_list,
          proposal_lens_list,
          select_proposal_len_zero=False)
      _, non_spec_indices = split_batch_by_proposal_len(
          seq_group_metadata_list,
          proposal_lens_list,
          select_proposal_len_zero=True)
      original_indices = spec_indices + non_spec_indices

      # Get probabilities of target model, excluding bonus token.
      proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

      # Get non-speculative sampled tokens from target model.
      non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

      # Get bonus tokens from target model.
      bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

      # Get probabilities according to proposal method.
      proposal_probs = proposals.proposal_probs[spec_indices]

      # Get proposed tokens.
      proposal_token_ids = proposals.proposal_token_ids[spec_indices]

      # Sampler arguments
      sampler_extra_kwargs: Dict[str, Any] = {}
      if isinstance(self.spec_decode_sampler,
                    SpecDecodeStochasticBaseSampler):
          # Per-sequence RNG of each sequence group (None when unseeded).
          batch_generators = [
              seq_group_metadata.state.generator
              if seq_group_metadata.state is not None else None
              for seq_group_metadata in seq_group_metadata_list
          ]
          # The sampler only sees the speculative rows, in spec_indices
          # order, like every other tensor passed to it above. Align the
          # generators with those rows. A full-batch list would pair sampler
          # row i with batch sequence i, which is wrong (and longer than the
          # sampler's batch) whenever some sequences have no proposal.
          sampler_extra_kwargs["generators"] = [
              batch_generators[i] for i in spec_indices
          ]

      accepted_token_ids = self.spec_decode_sampler(
          target_probs=proposal_verifier_probs,
          bonus_token_ids=bonus_token_ids,
          draft_probs=proposal_probs,
          draft_token_ids=proposal_token_ids,
          **sampler_extra_kwargs,
      )

      # Append output tokens from non-speculative sequences to
      # the accepted token ids tensor. Only their first column is a real
      # token; the rest is padded with -1.
      non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                     1).clone()
      non_spec_token_ids[:, 1:] = -1
      accepted_token_ids = torch.cat(
          [accepted_token_ids, non_spec_token_ids])
      logprobs = proposal_scores.logprobs

      # Rearrange so that results are in the order of the original seq group
      # metadata: row j of the concatenation belongs to original_indices[j].
      accepted_token_ids[original_indices] = accepted_token_ids.clone()

      hidden_states = proposal_scores.hidden_states
      if hidden_states is not None:
          # Contract hidden states based on accepted tokens: for each
          # sequence keep the hidden state at its last accepted position.
          # This assumes -1 entries only trail the accepted tokens.
          hs_size = hidden_states.shape[1]
          hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
                                                hs_size)
          accepted_index = accepted_token_ids + 1  # Convert -1 to 0
          accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
          index = accepted_index[:, None, None].expand(-1, 1, hs_size)
          hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
          # Store hidden states from target model for subsequent decode step
          self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                     hidden_states)

      return accepted_token_ids, logprobs
  555. def _create_output_sampler_list(
  556. self,
  557. seq_group_metadata_list: List[SequenceGroupMetadata],
  558. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  559. target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
  560. k: int,
  561. ) -> List[SamplerOutput]:
  562. """Given the accepted token ids, create a list of SamplerOutput.
  563. The output is padded with -1 tokens such that each sequence has
  564. the same number of outputs.
  565. """
  566. batch_size, num_steps = accepted_token_ids.shape
  567. accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
  568. if self._disable_logprobs:
  569. # We are skipping the logprobs. Hence don't serialize the
  570. # logprobs related tensors from the GPU. Instead create
  571. # empty/dummy lists.
  572. (accepted_token_id_ranks_by_step,
  573. accepted_token_id_logprobs_by_step,
  574. topk_logprobs_by_step, topk_indices_by_step) =\
  575. self._create_dummy_logprob_lists(
  576. batch_size, num_steps,
  577. self.scorer_worker.model_config.max_logprobs)
  578. else:
  579. # Organize input tensors by step instead of by sequence.
  580. target_logprobs_by_step = target_logprobs.transpose(0, 1)
  581. # Serialize all tensors into Python lists.
  582. (accepted_token_id_ranks_by_step,
  583. accepted_token_id_logprobs_by_step,
  584. topk_logprobs_by_step, topk_indices_by_step) =\
  585. self._create_logprob_lists_from_tensors(
  586. target_logprobs_by_step, accepted_token_ids_by_step,
  587. self.scorer_worker.model_config.max_logprobs)
  588. # Get the sequence ids and num_logprobs (sampling parameter) in the
  589. # batch.
  590. seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
  591. seq_group_metadata_list)
  592. num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
  593. # Serialize tensor to CPU Python list.
  594. accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
  595. # Construct the output on a per-step, per-sequence basis.
  596. sampler_output_list: List[SamplerOutput] = []
  597. for step_index in range(num_steps):
  598. if all(token_id == -1
  599. for token_id in accepted_token_ids_by_step[step_index]):
  600. break
  601. step_output_token_ids: List[CompletionSequenceGroupOutput] = []
  602. for sequence_index in range(batch_size):
  603. # Each sequence may have a different num_logprobs; retrieve it.
  604. num_logprobs = num_logprobs_per_seq[sequence_index]
  605. step_output_token_ids.append(
  606. create_sequence_group_output(
  607. token_id=accepted_token_ids_by_step[step_index]
  608. [sequence_index],
  609. token_id_logprob_rank=accepted_token_id_ranks_by_step[
  610. step_index][sequence_index],
  611. token_id_logprob=accepted_token_id_logprobs_by_step[
  612. step_index][sequence_index],
  613. seq_id=seq_ids[sequence_index],
  614. topk_token_ids=topk_indices_by_step[step_index]
  615. [sequence_index][:num_logprobs],
  616. topk_logprobs=topk_logprobs_by_step[step_index]
  617. [sequence_index][:num_logprobs],
  618. ))
  619. sampler_output_list.append(
  620. SamplerOutput(outputs=step_output_token_ids))
  621. # Populate the data structures needed to keep track of sequences with
  622. # bonus tokens.
  623. self._track_sequences_with_bonus_tokens(seq_ids,
  624. request_ids_seq_ids_mapping,
  625. accepted_token_ids_by_step)
  626. maybe_rejsample_metrics = (
  627. self._metrics.maybe_collect_rejsample_metrics(k))
  628. if maybe_rejsample_metrics is not None:
  629. sampler_output_list[
  630. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  631. return sampler_output_list
  632. def _create_dummy_logprob_lists(
  633. self,
  634. batch_size: int,
  635. num_steps: int,
  636. num_top_k: int,
  637. ) -> Tuple[List[List[int]], List[List[float]],
  638. List[List[List[Optional[float]]]],
  639. List[List[List[Optional[int]]]]]:
  640. """
  641. Creates and returns four dummy lists representing token probabilities
  642. and their ranks.
  643. This method initializes and returns:
  644. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  645. - The log probabilities of the accepted tokens,
  646. shaped (num_steps, batch_size)
  647. - The log probabilities of the top k tokens,
  648. shaped (num_steps, batch_size, num_top_k)
  649. - The token IDs of the top k tokens,
  650. shaped (num_steps, batch_size, num_top_k)
  651. Args:
  652. batch_size (int): The size of the batch.
  653. num_steps (int): The number of steps in the sequence.
  654. num_top_k (int): The number of top-k token log probabilities to
  655. return.
  656. Returns:
  657. A tuple containing four dummy lists as described above.
  658. """
  659. accepted_token_id_ranks_by_step = [[-1] * batch_size
  660. for _ in range(num_steps)]
  661. accepted_token_id_logprobs_by_step = [[0.0] * batch_size
  662. for _ in range(num_steps)]
  663. topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
  664. [None] * num_top_k for _ in range(batch_size)
  665. ] for _ in range(num_steps)]
  666. topk_indices_by_step: List[List[List[Optional[int]]]] = [[
  667. [None] * num_top_k for _ in range(batch_size)
  668. ] for _ in range(num_steps)]
  669. return (accepted_token_id_ranks_by_step,
  670. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  671. topk_indices_by_step)
  672. def _create_logprob_lists_from_tensors(
  673. self,
  674. target_logprobs_by_step: torch.Tensor,
  675. accepted_token_ids_by_step: torch.Tensor,
  676. num_top_k: int,
  677. ) -> Tuple[List[List[int]], List[List[float]],
  678. List[List[List[Optional[float]]]],
  679. List[List[List[Optional[int]]]]]:
  680. """
  681. Creates and returns four lists representing token probabilities and
  682. their ranks.
  683. This method initializes and returns four lists containing:
  684. - The ranks of the accepted tokens, shaped (num_steps, batch_size)
  685. - The log probabilities of the accepted tokens,
  686. shaped (num_steps, batch_size)
  687. - The log probabilities of the top k tokens,
  688. shaped (num_steps, batch_size, num_top_k)
  689. - The token IDs of the top k tokens,
  690. shaped (num_steps, batch_size, num_top_k)
  691. Args:
  692. target_logprobs_by_step (torch.Tensor): Tensor representing the
  693. log probabilities of the target model,
  694. shaped (num_steps, batch_size, vocab_size)
  695. accepted_token_ids_by_step (torch.Tensor): Tensor representing
  696. the accepted token_ids, shaped (num_steps, batch_size)
  697. num_top_k (int): The number of top-k token log probabilities to
  698. return.
  699. Returns:
  700. A tuple containing the lists as described above.
  701. """
  702. # Serialize all tensors to CPU Python lists.
  703. # Get the logprobs/rank of the accepted tokens.
  704. (accepted_token_id_ranks_by_step_tensor,
  705. accepted_token_id_logprobs_by_step_tensor
  706. ) = get_sampled_token_logprobs(
  707. logprob_tensor=target_logprobs_by_step,
  708. sampled_token_ids=accepted_token_ids_by_step,
  709. )
  710. # Get the top-k logprobs (which may or may not include the
  711. # logprob of the accepted token).
  712. (topk_logprobs_by_step_tensor,
  713. topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
  714. k=num_top_k,
  715. dim=-1,
  716. )
  717. accepted_token_id_ranks_by_step = (
  718. accepted_token_id_ranks_by_step_tensor.tolist())
  719. accepted_token_id_logprobs_by_step = (
  720. accepted_token_id_logprobs_by_step_tensor.tolist())
  721. topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
  722. topk_indices_by_step = topk_indices_by_step_tensor.tolist()
  723. return (accepted_token_id_ranks_by_step,
  724. accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
  725. topk_indices_by_step)
  726. def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
  727. """
  728. Removes the finished requests and their associated sequence ids from
  729. internal book keeping data structures.
  730. """
  731. for finished_request in execute_model_req.finished_requests_ids:
  732. for seq_id in self._request_id_seq_id_mapping[finished_request]:
  733. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  734. del self._request_id_seq_id_mapping[finished_request]
  735. def _track_sequences_with_bonus_tokens(
  736. self, seq_ids: List[int],
  737. request_ids_seq_ids_mapping: Dict[str, Set[int]],
  738. accepted_token_ids_by_step: List[List[int]]):
  739. """
  740. Updates the internal data structures which keep track of sequences
  741. which have been assigned bonus tokens in their last forward pass.
  742. """
  743. for seq_index, seq_id in enumerate(seq_ids):
  744. last_token_id = accepted_token_ids_by_step[-1][seq_index]
  745. if last_token_id == -1:
  746. self._seq_with_bonus_token_in_last_step.discard(seq_id)
  747. else:
  748. self._seq_with_bonus_token_in_last_step.add(seq_id)
  749. for request_id, sequences in request_ids_seq_ids_mapping.items():
  750. self._request_id_seq_id_mapping[request_id].update(sequences)
  751. @cached_property
  752. def _vocab_size(self) -> int:
  753. """Get the vocab size of the model and make sure it's consistent between
  754. draft and target workers.
  755. """
  756. vocab_sizes = [
  757. worker.vocab_size
  758. for worker in [self.proposer_worker, self.scorer_worker]
  759. ]
  760. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  761. return vocab_sizes[0]
  762. @property
  763. def rank(self):
  764. return self.scorer_worker.rank
  765. @property
  766. def device(self):
  767. return self.scorer_worker.device
  768. @property
  769. def _driver_rank(self) -> int:
  770. return 0
  771. def get_cache_block_size_bytes(self):
  772. """Return the size of a cache block in bytes.
  773. This function is only used to compose workers within a SpecDecodeWorker.
  774. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  775. undefined for now, although it could be implemented in the future.
  776. See https://arxiv.org/abs/2308.04623.
  777. """
  778. raise NotImplementedError
  779. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  780. proposer_cache_block_size_bytes: int,
  781. total_num_gpu_blocks: int) -> int:
  782. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  783. allocate to the target model, this function calculates how many blocks
  784. should be given to the draft and target model.
  785. Note that usually the block size, in bytes, of each model is different,
  786. as it's a function of number of KV/layer, number of heads, and hidden
  787. dimension size.
  788. Since the target and draft models allocate the same number of blocks, we
  789. simply calculate the number of blocks where if allocated by both models,
  790. the total memory usage from KV cache is no larger than the number of
  791. blocks allocatable by the target model alone.
  792. """
  793. new_num_gpu_blocks = int(
  794. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  795. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  796. return new_num_gpu_blocks