1
0

# spec_decode_worker.py (20 KB)
from functools import cached_property
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch
from loguru import logger

from aphrodite.common.config import SchedulerConfig
from aphrodite.common.sequence import (Logprob, SamplerOutput,
                                       SequenceGroupMetadata,
                                       SequenceGroupOutput, SequenceOutput)
from aphrodite.modeling.layers.rejection import RejectionSampler
from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                              SpeculativeScorer,
                                              SpeculativeScores)
from aphrodite.spec_decode.metrics import AsyncMetricsCollector
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.ngram_worker import NGramWorker
from aphrodite.spec_decode.util import (get_all_seq_ids, nvtx_range,
                                        split_batch_by_proposal_len)
from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
                                                WorkerBase)

if TYPE_CHECKING:
    # Only needed for annotations; not imported at runtime.
    from aphrodite.common.config import SpeculativeConfig
  21. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  22. """Worker which implements speculative decoding.
  23. Speculative decoding reduces decoding per-token latency by using a proposal
  24. method, such as a small draft model, to speculate ahead of a larger LLM. The
  25. probabilities of the speculative tokens are then determined by the larger
  26. LLM, after which some verification routine determines which (if any) of the
  27. speculative tokens are accepted by the larger LLM.
  28. The current implementation has the following limitations:
  29. * Only draft-model proposal is implemented (contributions for more forms are
  30. welcome!).
  31. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  32. future work.
  33. * Only lossless rejection sampling is supported. Contributions adding lossy
  34. verification routines are welcome (e.g. Medusa's typical acceptance).
  35. * All sequences in a batch must have the same proposal length, or zero. This
  36. can be improved by having per-sequence speculation in the future.
  37. * The scoring forward pass is done without an MQA kernel, which is
  38. suboptimal especially as the batch size, proposal length, and sequence
  39. lengths grow. Contributions to add a MQA scoring are welcome once
  40. correctness tests pass.
  41. """
  42. @classmethod
  43. def create_worker(
  44. cls,
  45. scorer_worker: WorkerBase,
  46. speculative_config: SchedulerConfig,
  47. ) -> "SpecDecodeWorker":
  48. if speculative_config.ngram_prompt_lookup_max > 0:
  49. proposer_worker = NGramWorker(
  50. model_config=speculative_config.draft_model_config,
  51. parallel_config=speculative_config.draft_parallel_config,
  52. scheduler_config=scorer_worker.scheduler_config,
  53. device_config=scorer_worker.device_config,
  54. cache_config=scorer_worker.cache_config,
  55. local_rank=0,
  56. rank=0,
  57. distributed_init_method=scorer_worker.distributed_init_method,
  58. )
  59. proposer_worker.set_ngram_window_size(
  60. speculative_config.ngram_prompt_lookup_min,
  61. speculative_config.ngram_prompt_lookup_max)
  62. else:
  63. proposer_worker = MultiStepWorker(
  64. model_config=speculative_config.draft_model_config,
  65. parallel_config=speculative_config.draft_parallel_config,
  66. scheduler_config=scorer_worker.scheduler_config,
  67. device_config=scorer_worker.device_config,
  68. cache_config=scorer_worker.cache_config,
  69. local_rank=0,
  70. rank=0,
  71. distributed_init_method=scorer_worker.distributed_init_method,
  72. lora_config=scorer_worker.lora_config,
  73. vision_language_config=scorer_worker.vision_language_config,
  74. is_driver_worker=True,
  75. )
  76. return SpecDecodeWorker(
  77. proposer_worker,
  78. scorer_worker,
  79. # TODO: disable strict mode for speedup.
  80. rejection_sampler=RejectionSampler(strict_mode=True),
  81. )
  82. def __init__(
  83. self,
  84. proposer_worker: WorkerBase,
  85. scorer_worker: WorkerBase,
  86. rejection_sampler: RejectionSampler,
  87. metrics_collector: Optional[AsyncMetricsCollector] = None,
  88. ):
  89. """
  90. Create a SpecDecodeWorker.
  91. Args:
  92. proposer_worker: A worker that can produce speculative tokens for
  93. sequences.
  94. scorer_worker: A worker that produces probabilities of speculative
  95. tokens according to some base model. Typically a vanilla
  96. Aphrodite Worker.
  97. rejection_sampler: A Torch module used to perform modified rejection
  98. sampling for speculative decoding.
  99. metrics_collector: Helper class for collecting metrics; can be set
  100. for testing purposes.
  101. """
  102. self.proposer_worker = proposer_worker
  103. self.scorer_worker = scorer_worker
  104. self.rejection_sampler = rejection_sampler
  105. self._metrics = AsyncMetricsCollector(
  106. rejection_sampler
  107. ) if metrics_collector is None else metrics_collector
  108. self.probs_dtype = self.rejection_sampler.probs_dtype
  109. self.token_id_dtype = self.rejection_sampler.token_id_dtype
  110. # Lazy initiazliation.
  111. self.scorer: SpeculativeScorer
  112. def init_device(self) -> None:
  113. """Initialize both scorer and proposer models.
  114. """
  115. # The scorer worker model is initialized first in case the proposer
  116. # model has a smaller TP degree than the target worker.
  117. self.scorer_worker.init_device()
  118. self.proposer_worker.init_device()
  119. # NOTE: load_model is not part of the WorkerBase interface.
  120. self.scorer_worker.load_model()
  121. self.proposer_worker.load_model()
  122. self._metrics.init_gpu_tensors(self.rank)
  123. self.rejection_sampler.init_gpu_tensors(self.rank)
  124. self.scorer = BatchExpansionTop1Scorer(
  125. scorer_worker=self.scorer_worker,
  126. device=self.device,
  127. vocab_size=self._vocab_size)
  128. self._configure_model_sampler_for_spec_decode()
  129. def _configure_model_sampler_for_spec_decode(self):
  130. """Configure model sampler to emit GPU tensors. This allows spec decode
  131. to keep data on device without transferring to CPU and serializing,
  132. which significantly reduces overhead of rejection sampling.
  133. NOTE: This breaks abstraction boundaries pretty badly. The better
  134. design is to have the "move to CPU and serialize" sampling decision be
  135. done outside of the model/sampler; this way the "last-mile" worker
  136. object which interfaces with the scheduler can serialize and incur the
  137. performance hit as necessary. This allows us to run the worker several
  138. iterations in a row without incurring the "move to CPU and serialize"
  139. performance penalty.
  140. Since this requires a large change to Aphrodite, we defer it to later
  141. and temporarily accept this broken abstraction boundary.
  142. NOTE: This will require a special check if the proposer worker
  143. does not have a sampler (e.g. ngram speculation).
  144. """
  145. (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
  146. ) = True
  147. self.proposer_worker.set_include_gpu_probs_tensor()
  148. def determine_num_available_blocks(self) -> Tuple[int, int]:
  149. """Determine the number of cache blocks to use.
  150. This is done by profiling the scorer model (which is typically the
  151. larger of the two). Then the total memory which would be used by the
  152. scorer cache is divided evenly between the proposer and scorer model KV,
  153. such that the number of blocks is equal in both KV caches.
  154. """
  155. num_gpu_blocks, num_cpu_blocks = (
  156. self.scorer_worker.determine_num_available_blocks())
  157. scorer_cache_block_size_bytes = (
  158. self.scorer_worker.get_cache_block_size_bytes())
  159. proposer_cache_block_size_bytes = (
  160. self.proposer_worker.get_cache_block_size_bytes())
  161. new_num_gpu_blocks = split_num_cache_blocks_evenly(
  162. scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
  163. num_gpu_blocks)
  164. return new_num_gpu_blocks, num_cpu_blocks
  165. def initialize_cache(self, num_gpu_blocks: int,
  166. num_cpu_blocks: int) -> None:
  167. """Initialize the cache engine of the scorer and proposer workers.
  168. """
  169. self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  170. num_cpu_blocks=num_cpu_blocks)
  171. self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  172. num_cpu_blocks=num_cpu_blocks)
  173. @torch.inference_mode()
  174. def execute_model(
  175. self,
  176. seq_group_metadata_list: List[SequenceGroupMetadata],
  177. blocks_to_swap_in: Optional[Dict[int, int]],
  178. blocks_to_swap_out: Optional[Dict[int, int]],
  179. blocks_to_copy: Optional[Dict[int, List[int]]],
  180. num_lookahead_slots: int,
  181. ) -> List[SamplerOutput]:
  182. """Perform speculative decoding on the input batch.
  183. """
  184. assert seq_group_metadata_list is not None, (
  185. "speculative decoding "
  186. "requires non-None seq_group_metadata_list")
  187. logger.debug(
  188. f"spec_decode_worker.execute_model {num_lookahead_slots=}")
  189. # If no spec tokens, call the proposer and scorer workers normally.
  190. # Used for prefill.
  191. if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
  192. return self._run_no_spec(
  193. seq_group_metadata_list=seq_group_metadata_list,
  194. blocks_to_swap_in=blocks_to_swap_in,
  195. blocks_to_swap_out=blocks_to_swap_out,
  196. blocks_to_copy=blocks_to_copy,
  197. )
  198. return self._run_speculative_decoding_step(
  199. seq_group_metadata_list=seq_group_metadata_list,
  200. blocks_to_swap_in=blocks_to_swap_in,
  201. blocks_to_swap_out=blocks_to_swap_out,
  202. blocks_to_copy=blocks_to_copy,
  203. k=num_lookahead_slots,
  204. )
  205. @nvtx_range("spec_decode_worker._run_no_spec")
  206. def _run_no_spec(
  207. self,
  208. seq_group_metadata_list: List[SequenceGroupMetadata],
  209. blocks_to_swap_in: Optional[Dict[int, int]],
  210. blocks_to_swap_out: Optional[Dict[int, int]],
  211. blocks_to_copy: Optional[Dict[int, List[int]]],
  212. ) -> List[SamplerOutput]:
  213. """Run a prefill step, without any speculation. The input is sent to the
  214. proposer and scorer model so that the KV cache is consistent between the
  215. two.
  216. """
  217. logger.debug("run proposer worker no spec")
  218. self.proposer_worker.execute_model(
  219. seq_group_metadata_list=seq_group_metadata_list,
  220. blocks_to_swap_in=blocks_to_swap_in,
  221. blocks_to_swap_out=blocks_to_swap_out,
  222. blocks_to_copy=blocks_to_copy,
  223. )
  224. logger.debug("run target worker no spec")
  225. sampler_output = self.scorer_worker.execute_model(
  226. seq_group_metadata_list=seq_group_metadata_list,
  227. blocks_to_swap_in=blocks_to_swap_in,
  228. blocks_to_swap_out=blocks_to_swap_out,
  229. blocks_to_copy=blocks_to_copy,
  230. )
  231. assert len(sampler_output) == 1
  232. sampler_output = sampler_output[0]
  233. # Clear device tensors from sampler output. This reduces communication
  234. # overhead when the engine runs in a different process than the workers.
  235. sampler_output.probs = None
  236. sampler_output.sampled_tokens = None
  237. return [sampler_output]
  238. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  239. def _run_speculative_decoding_step(
  240. self,
  241. seq_group_metadata_list: List[SequenceGroupMetadata],
  242. blocks_to_swap_in: Optional[Dict[int, int]],
  243. blocks_to_swap_out: Optional[Dict[int, int]],
  244. blocks_to_copy: Optional[Dict[int, List[int]]],
  245. k: int,
  246. ) -> List[SamplerOutput]:
  247. """Execute a single step of speculative decoding.
  248. This invokes the proposer worker to get k speculative tokens for each
  249. sequence, then scores each speculative token using the scoring worker.
  250. Returns a list of SamplerOutput, each containing a single token per
  251. sequence.
  252. """
  253. logger.debug("get spec proposals")
  254. # Generate proposals using draft worker.
  255. assert blocks_to_swap_in is not None
  256. assert blocks_to_swap_out is not None
  257. assert blocks_to_copy is not None
  258. proposals = self.proposer_worker.get_spec_proposals(
  259. seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
  260. blocks_to_copy, k)
  261. logger.debug("score proposals")
  262. proposal_scores = self.scorer.score_proposals(
  263. seq_group_metadata_list,
  264. blocks_to_swap_in,
  265. blocks_to_swap_out,
  266. blocks_to_copy,
  267. k,
  268. proposals,
  269. )
  270. logger.debug("verify proposals")
  271. accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
  272. proposal_scores, proposals, k)
  273. logger.debug("create output list")
  274. return self._create_output_sampler_list(seq_group_metadata_list,
  275. accepted_token_ids, k)
  276. @nvtx_range("spec_decode_worker._verify_tokens")
  277. def _verify_tokens(
  278. self,
  279. seq_group_metadata_list: List[SequenceGroupMetadata],
  280. proposal_scores: SpeculativeScores,
  281. proposals: SpeculativeProposals,
  282. max_proposal_len: int,
  283. ) -> torch.Tensor:
  284. """Determine which speculative tokens are accepted using the
  285. probabilities of each token according to the proposer and scorer models.
  286. """
  287. proposal_lens_list = proposals.proposal_lens.tolist()
  288. # Aphrodite currently only supports proposal lens equal to zero or the
  289. # batch proposal len. This adds some complexity (splitting the batch
  290. # into spec and non spec sequences) and should be removed in the
  291. # future. It can be done by supporting per-sequence proposal lens.
  292. _, spec_indices = split_batch_by_proposal_len(
  293. seq_group_metadata_list,
  294. proposal_lens_list,
  295. select_proposal_len_zero=False)
  296. _, non_spec_indices = split_batch_by_proposal_len(
  297. seq_group_metadata_list,
  298. proposal_lens_list,
  299. select_proposal_len_zero=True)
  300. original_indices = spec_indices + non_spec_indices
  301. # Get probabilities of target model, excluding bonus token.
  302. proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
  303. # Get non-speculative sampled tokens from target model.
  304. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
  305. # Get bonus tokens from target model.
  306. bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
  307. # Get probabilities according to proposal method.
  308. proposal_probs = proposals.proposal_probs[spec_indices]
  309. # Get proposed tokens.
  310. proposal_token_ids = proposals.proposal_token_ids[spec_indices]
  311. accepted_token_ids = self.rejection_sampler(
  312. target_probs=proposal_verifier_probs,
  313. bonus_token_ids=bonus_token_ids,
  314. draft_probs=proposal_probs,
  315. draft_token_ids=proposal_token_ids,
  316. )
  317. # Append output tokens from non-speculative sequences to
  318. # the accepted token ids tensor.
  319. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
  320. 1).clone()
  321. non_spec_token_ids[:, 1:] = -1
  322. accepted_token_ids = torch.cat(
  323. [accepted_token_ids, non_spec_token_ids])
  324. # Rearrange so that results are in the order of the original seq group
  325. # metadata.
  326. accepted_token_ids[original_indices] = accepted_token_ids.clone()
  327. return accepted_token_ids
  328. def _create_output_sampler_list(
  329. self,
  330. seq_group_metadata_list: List[SequenceGroupMetadata],
  331. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  332. k: int,
  333. ) -> List[SamplerOutput]:
  334. """Given the accepted token ids, create a list of SamplerOutput.
  335. The output is padded with -1 tokens such that each sequence has
  336. the same number of outputs.
  337. """
  338. seq_ids = get_all_seq_ids(seq_group_metadata_list)
  339. # shape: [k+1, batch_size]
  340. accepted_token_ids_by_step = accepted_token_ids.transpose(0,
  341. 1).tolist()
  342. sampler_output_list = []
  343. for token_ids_by_step in accepted_token_ids_by_step:
  344. if all(token_id == -1 for token_id in token_ids_by_step):
  345. break
  346. step_output_token_ids = []
  347. for token_id, seq_id in zip(token_ids_by_step, seq_ids):
  348. step_output_token_ids.append(
  349. SequenceGroupOutput(
  350. samples=[
  351. SequenceOutput(
  352. parent_seq_id=seq_id,
  353. output_token=token_id,
  354. # TODO Add verifier logprobs.
  355. logprobs={token_id: Logprob(0.0)},
  356. persistent_data={},
  357. )
  358. ],
  359. prompt_logprobs=None,
  360. ))
  361. sampler_output_list.append(
  362. SamplerOutput(outputs=step_output_token_ids))
  363. maybe_rejsample_metrics = (
  364. self._metrics.maybe_collect_rejsample_metrics(k))
  365. if maybe_rejsample_metrics is not None:
  366. sampler_output_list[
  367. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  368. return sampler_output_list
  369. @cached_property
  370. def _vocab_size(self) -> int:
  371. """Get the vocab size of the model and make sure it's consistent between
  372. draft and target workers.
  373. """
  374. vocab_sizes = [
  375. worker.vocab_size
  376. for worker in [self.proposer_worker, self.scorer_worker]
  377. ]
  378. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  379. return vocab_sizes[0]
  380. @property
  381. def rank(self):
  382. return self.scorer_worker.rank
  383. @property
  384. def device(self):
  385. return self.scorer_worker.device
  386. def get_cache_block_size_bytes(self):
  387. """Return the size of a cache block in bytes.
  388. This function is only used to compose workers within a SpecDecodeWorker.
  389. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  390. undefined for now, although it could be implemented in the future.
  391. See https://arxiv.org/abs/2308.04623.
  392. """
  393. raise NotImplementedError
  394. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  395. proposer_cache_block_size_bytes: int,
  396. total_num_gpu_blocks: int) -> int:
  397. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  398. allocate to the target model, this function calculates how many blocks
  399. should be given to the draft and target model.
  400. Note that usually the block size, in bytes, of each model is different,
  401. as it's a function of number of KV/layer, number of heads, and hidden
  402. dimension size.
  403. Since the target and draft models allocate the same number of blocks, we
  404. simply calculate the number of blocks where if allocated by both models,
  405. the total memory usage from KV cache is no larger than the number of
  406. blocks allocatable by the target model alone.
  407. """
  408. new_num_gpu_blocks = int(
  409. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  410. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  411. return new_num_gpu_blocks