spec_decode_worker.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. from typing import List, Tuple, Optional, Dict
  2. from functools import cached_property
  3. import torch
  4. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  5. from aphrodite.common.sequence import (
  6. SamplerOutput,
  7. SequenceGroupMetadata,
  8. SequenceGroupOutput,
  9. SequenceOutput,
  10. )
  11. from aphrodite.task_handler.worker import Worker
  12. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  13. from aphrodite.modeling.layers.rejection import RejectionSampler
  14. from aphrodite.common.config import CacheConfig
  15. from aphrodite.spec_decode.util import (
  16. nvtx_range,
  17. get_all_seq_ids,
  18. split_batch_by_proposal_len,
  19. )
  20. from aphrodite.spec_decode.interfaces import (
  21. SpeculativeProposals,
  22. SpeculativeScores,
  23. )
  24. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  25. from aphrodite.spec_decode.interfaces import SpeculativeScorer
  26. class SpecDecodeWorker:
  27. """Worker which implements speculative decoding.
  28. Speculative decoding reduces decoding per-token latency by using a proposal
  29. method, such as a small draft model, to speculate ahead of a larger LLM. The
  30. probabilities of the speculative tokens are then determined by the larger
  31. LLM, after which some verification routine determines which (if any) of the
  32. speculative tokens are accepted by the larger LLM.
  33. The current implementation has the following limitations:
  34. * Only draft-model proposal is implemented (contributions for more forms are
  35. welcome!).
  36. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  37. future work.
  38. * Only lossless rejection sampling is supported. Contributions adding lossy
  39. verification routines are welcome (e.g. Medusa's typical acceptance).
  40. * All sequences in a batch must have the same proposal length, or zero. This
  41. can be improved by having per-sequence speculation in the future.
  42. * The scoring forward pass is done without an MQA kernel, which is
  43. suboptimal especially as the batch size, proposal length, and sequence
  44. lengths grow. Contributions to add a MQA scoring are welcome once
  45. correctness tests pass.
  46. """
  47. def __init__(
  48. self,
  49. proposer_worker: MultiStepWorker,
  50. scorer_worker: Worker,
  51. rejection_sampler: RejectionSampler,
  52. metrics_collector: Optional[AsyncMetricsCollector] = None,
  53. ):
  54. """
  55. Create a SpecDecodeWorker.
  56. Args:
  57. proposer_worker: A worker that can produce speculative tokens for
  58. sequences.
  59. scorer_worker: A worker that produces probabilities of speculative
  60. tokens according to some base model. Typically a vanilla
  61. Aphrodite Worker.
  62. rejection_sampler: A Torch module used to perform modified rejection
  63. sampling for speculative decoding.
  64. metrics_collector: Helper class for collecting metrics; can be set
  65. for testing purposes.
  66. """
  67. self.proposer_worker = proposer_worker
  68. self.scorer_worker = scorer_worker
  69. self.rejection_sampler = rejection_sampler
  70. self._metrics = (AsyncMetricsCollector(rejection_sampler)
  71. if metrics_collector is None else metrics_collector)
  72. self.probs_dtype = self.rejection_sampler.probs_dtype
  73. self.token_id_dtype = self.rejection_sampler.token_id_dtype
  74. self.scorer: SpeculativeScorer = None
  75. def init_model(self) -> None:
  76. """Initialize both scorer and proposer models."""
  77. # The scorer worker model is initialized first in case the proposer
  78. # model has a smaller TP degree than the target worker.
  79. self.scorer_worker.init_model()
  80. self.proposer_worker.init_model()
  81. self._metrics.init_gpu_tensors(self.rank)
  82. self.rejection_sampler.init_gpu_tensors(self.rank)
  83. self.scorer = BatchExpansionTop1Scorer(
  84. scorer_worker=self.scorer_worker,
  85. device=self.device,
  86. vocab_size=self._vocab_size,
  87. )
  88. def profile_num_available_blocks(
  89. self,
  90. block_size: int,
  91. gpu_memory_utilization: float,
  92. cpu_swap_space: int,
  93. cache_dtype: str,
  94. ) -> Tuple[int, int]:
  95. """Determine the number of cache blocks to use.
  96. This is done by profiling the scorer model (which is typically the
  97. larger of the two). Then the total memory which would be used by the
  98. scorer cache is divided evenly between the proposer and scorer model KV,
  99. such that the number of blocks is equal in both KV caches.
  100. """
  101. (
  102. num_gpu_blocks,
  103. num_cpu_blocks,
  104. ) = self.scorer_worker.profile_num_available_blocks(
  105. block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype)
  106. scorer_cache_block_size_bytes = (
  107. self.scorer_worker.get_cache_block_size_bytes(
  108. block_size, cache_dtype))
  109. proposer_cache_block_size_bytes = (
  110. self.proposer_worker.get_cache_block_size_bytes(
  111. block_size, cache_dtype))
  112. new_num_gpu_blocks = split_num_cache_blocks_evenly(
  113. scorer_cache_block_size_bytes,
  114. proposer_cache_block_size_bytes,
  115. num_gpu_blocks,
  116. )
  117. return new_num_gpu_blocks, num_cpu_blocks
  118. def init_cache_engine(self, cache_config: CacheConfig):
  119. """Initialize the cache engine of the scorer and proposer workers."""
  120. self.scorer_worker.init_cache_engine(cache_config)
  121. self.proposer_worker.init_cache_engine(cache_config)
  122. @torch.inference_mode()
  123. def execute_model(
  124. self,
  125. seq_group_metadata_list: List[SequenceGroupMetadata],
  126. blocks_to_swap_in: Optional[Dict[int, int]],
  127. blocks_to_swap_out: Optional[Dict[int, int]],
  128. blocks_to_copy: Optional[Dict[int, List[int]]],
  129. num_spec_tokens: int,
  130. ) -> List[SamplerOutput]:
  131. """Perform speculative decoding on the input batch."""
  132. assert seq_group_metadata_list is not None, (
  133. "speculative decoding "
  134. "requires non-None seq_group_metadata_list")
  135. # If no spec tokens, call the proposer and scorer workers normally.
  136. # Used for prefill.
  137. if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
  138. return self._run_no_spec(
  139. seq_group_metadata_list=seq_group_metadata_list,
  140. blocks_to_swap_in=blocks_to_swap_in,
  141. blocks_to_swap_out=blocks_to_swap_out,
  142. blocks_to_copy=blocks_to_copy,
  143. )
  144. return self._run_speculative_decoding_step(
  145. seq_group_metadata_list=seq_group_metadata_list,
  146. blocks_to_swap_in=blocks_to_swap_in,
  147. blocks_to_swap_out=blocks_to_swap_out,
  148. blocks_to_copy=blocks_to_copy,
  149. k=num_spec_tokens,
  150. )
  151. @nvtx_range("spec_decode_worker._run_no_spec")
  152. def _run_no_spec(
  153. self,
  154. seq_group_metadata_list: List[SequenceGroupMetadata],
  155. blocks_to_swap_in: Optional[Dict[int, int]],
  156. blocks_to_swap_out: Optional[Dict[int, int]],
  157. blocks_to_copy: Optional[Dict[int, List[int]]],
  158. ) -> List[SamplerOutput]:
  159. """Run a prefill step, without any speculation. The input is sent to the
  160. proposer and scorer model so that the KV cache is consistent between the
  161. two.
  162. """
  163. self.proposer_worker.execute_model(
  164. seq_group_metadata_list=seq_group_metadata_list,
  165. blocks_to_swap_in=blocks_to_swap_in,
  166. blocks_to_swap_out=blocks_to_swap_out,
  167. blocks_to_copy=blocks_to_copy,
  168. return_python_output=False,
  169. )
  170. sampler_output = self.scorer_worker.execute_model(
  171. seq_group_metadata_list=seq_group_metadata_list,
  172. blocks_to_swap_in=blocks_to_swap_in,
  173. blocks_to_swap_out=blocks_to_swap_out,
  174. blocks_to_copy=blocks_to_copy,
  175. )
  176. # Clear device tensors from sampler output. This reduces communication
  177. # overhead when the engine runs in a different process than the workers.
  178. sampler_output.probs = None
  179. sampler_output.sampled_tokens = None
  180. return [sampler_output]
  181. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  182. def _run_speculative_decoding_step(
  183. self,
  184. seq_group_metadata_list: List[SequenceGroupMetadata],
  185. blocks_to_swap_in: Optional[Dict[int, int]],
  186. blocks_to_swap_out: Optional[Dict[int, int]],
  187. blocks_to_copy: Optional[Dict[int, List[int]]],
  188. k: int,
  189. ) -> List[SamplerOutput]:
  190. """Execute a single step of speculative decoding.
  191. This invokes the proposer worker to get k speculative tokens for each
  192. sequence, then scores each speculative token using the scoring worker.
  193. Returns a list of SamplerOutput, each containing a single token per
  194. sequence.
  195. """
  196. # Generate proposals using draft worker.
  197. proposals = self.proposer_worker.get_spec_proposals(
  198. seq_group_metadata_list,
  199. blocks_to_swap_in,
  200. blocks_to_swap_out,
  201. blocks_to_copy,
  202. k,
  203. )
  204. proposal_scores = self.scorer.score_proposals(
  205. seq_group_metadata_list,
  206. blocks_to_swap_in,
  207. blocks_to_swap_out,
  208. blocks_to_copy,
  209. k,
  210. proposals,
  211. )
  212. accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
  213. proposal_scores, proposals, k)
  214. return self._create_output_sampler_list(seq_group_metadata_list,
  215. accepted_token_ids, k)
  216. @nvtx_range("spec_decode_worker._verify_tokens")
  217. def _verify_tokens(
  218. self,
  219. seq_group_metadata_list: List[SequenceGroupMetadata],
  220. proposal_scores: SpeculativeScores,
  221. proposals: SpeculativeProposals,
  222. max_proposal_len: int,
  223. ) -> torch.Tensor:
  224. """Determine which speculative tokens are accepted using the
  225. probabilities of each token according to the proposer and scorer models.
  226. """
  227. proposal_lens_list = proposals.proposal_lens.tolist()
  228. # Aphrodite currently only supports proposal lens equal to zero or the
  229. # batch proposal len. This adds some complexity (splitting the batch
  230. # into spec and non spec sequences) and should be removed in the
  231. # future. It can be done by supporting per-sequence proposal lens.
  232. _, spec_indices = split_batch_by_proposal_len(
  233. seq_group_metadata_list,
  234. proposal_lens_list,
  235. select_proposal_len_zero=False,
  236. )
  237. _, non_spec_indices = split_batch_by_proposal_len(
  238. seq_group_metadata_list,
  239. proposal_lens_list,
  240. select_proposal_len_zero=True,
  241. )
  242. original_indices = spec_indices + non_spec_indices
  243. proposal_probs = proposal_scores.probs[spec_indices, :-1]
  244. bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
  245. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
  246. accepted_token_ids = self.rejection_sampler(
  247. proposal_probs,
  248. bonus_token_ids,
  249. proposals.proposal_probs,
  250. proposals.proposal_token_ids,
  251. )
  252. # Append output tokens from non-speculative sequences to
  253. # the accepted token ids tensor.
  254. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
  255. 1).clone()
  256. non_spec_token_ids[:, 1:] = -1
  257. accepted_token_ids = torch.cat(
  258. [accepted_token_ids, non_spec_token_ids])
  259. # Rearrange so that results are in the order of the original seq group
  260. # metadata.
  261. accepted_token_ids[original_indices] = accepted_token_ids.clone()
  262. return accepted_token_ids
  263. def _create_output_sampler_list(
  264. self,
  265. seq_group_metadata_list: List[SequenceGroupMetadata],
  266. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  267. k: int,
  268. ) -> List[SamplerOutput]:
  269. """Given the accepted token ids, create a list of SamplerOutput.
  270. The output is padded with -1 tokens such that each sequence has
  271. the same number of outputs.
  272. """
  273. seq_ids = get_all_seq_ids(seq_group_metadata_list)
  274. # shape: [k+1, batch_size]
  275. accepted_token_ids_by_step = accepted_token_ids.transpose(0,
  276. 1).tolist()
  277. sampler_output_list = []
  278. for token_ids_by_step in accepted_token_ids_by_step:
  279. if all(token_id == -1 for token_id in token_ids_by_step):
  280. break
  281. step_output_token_ids = []
  282. for token_id, seq_id in zip(token_ids_by_step, seq_ids):
  283. step_output_token_ids.append(
  284. SequenceGroupOutput(
  285. samples=[
  286. SequenceOutput(
  287. parent_seq_id=seq_id,
  288. output_token=token_id,
  289. # TODO Add verifier logprobs.
  290. logprobs={token_id: 0.0},
  291. persistent_data={},
  292. )
  293. ],
  294. prompt_logprobs=None,
  295. ))
  296. sampler_output_list.append(
  297. SamplerOutput(outputs=step_output_token_ids))
  298. maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
  299. k)
  300. if maybe_rejsample_metrics is not None:
  301. sampler_output_list[
  302. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  303. return sampler_output_list
  304. @cached_property
  305. def _vocab_size(self) -> int:
  306. """Get the vocab size of the model and make sure it's consistent between
  307. draft and target workers.
  308. """
  309. vocab_sizes = [
  310. worker.vocab_size
  311. for worker in [self.proposer_worker, self.scorer_worker]
  312. ]
  313. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  314. return vocab_sizes[0]
  315. @property
  316. def rank(self):
  317. return self.scorer_worker.rank
  318. @property
  319. def device(self):
  320. return self.scorer_worker.device
  321. def split_num_cache_blocks_evenly(
  322. scorer_cache_block_size_bytes: int,
  323. proposer_cache_block_size_bytes: int,
  324. total_num_gpu_blocks: int,
  325. ) -> int:
  326. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  327. allocate to the target model, this function calculates how many blocks
  328. should be given to the draft and target model.
  329. Note that usually the block size, in bytes, of each model is different,
  330. as it's a function of number of KV/layer, number of heads, and hidden
  331. dimension size.
  332. Since the target and draft models allocate the same number of blocks, we
  333. simply calculate the number of blocks where if allocated by both models,
  334. the total memory usage from KV cache is no larger than the number of
  335. blocks allocatable by the target model alone.
  336. """
  337. new_num_gpu_blocks = int(
  338. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  339. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  340. return new_num_gpu_blocks