# spec_decode_worker.py

from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple

import torch
from loguru import logger

from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                        SequenceGroupMetadata)
from aphrodite.distributed.communication_op import broadcast_tensor_dict
from aphrodite.modeling.layers.rejection import RejectionSampler
from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                              SpeculativeScorer,
                                              SpeculativeScores)
from aphrodite.spec_decode.metrics import AsyncMetricsCollector
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.ngram_worker import NGramWorker
from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
from aphrodite.spec_decode.util import (create_sequence_group_output,
                                         get_all_num_logprobs, get_all_seq_ids,
                                         get_sampled_token_logprobs, nvtx_range,
                                         split_batch_by_proposal_len)
from aphrodite.task_handler.worker import Worker
from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
                                                 WorkerBase)


def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
    """
    assert "speculative_config" in kwargs
    speculative_config = kwargs.get("speculative_config")
    assert speculative_config is not None

    target_worker = Worker(*args, **kwargs)

    draft_worker_kwargs = kwargs.copy()
    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        model_config=speculative_config.draft_model_config,
        parallel_config=speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
        # TODO: allow draft-model specific load config.
        # load_config=load_config,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
    )

    return spec_decode_worker


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * Only lossless rejection sampling is supported. Contributions adding lossy
        verification routines are welcome (e.g. Medusa's typical acceptance).
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add MQA scoring are welcome once
        correctness tests pass.
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
    ) -> "SpecDecodeWorker":

        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        disable_bonus_tokens = True
        if ngram_prompt_lookup_max > 0:
            disable_bonus_tokens = False
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)

        logger.info("Configuring SpecDecodeWorker with "
                    f"proposer={type(proposer_worker)}")

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_by_batch_size=disable_by_batch_size,
            rejection_sampler=RejectionSampler(
                disable_bonus_tokens=disable_bonus_tokens, ))

    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.rejection_sampler = rejection_sampler

        self._metrics = AsyncMetricsCollector(
            rejection_sampler
        ) if metrics_collector is None else metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype

        # Lazy initialization.
        self.scorer: SpeculativeScorer

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.rejection_sampler.init_gpu_tensors(self.rank)
        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def load_model(self, *args, **kwargs):
        pass

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of rejection sampling.

        NOTE: This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE: This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
         ) = True
        self.proposer_worker.set_include_gpu_probs_tensor()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there are no more requests to process for now.
            # All workers are running an infinite loop with
            # broadcast_tensor_dict, and they stop the loop when the driver
            # broadcasts an empty input. Send an empty input to notify all
            # other workers to stop their execution loop.
            broadcast_tensor_dict({}, src=0)
            return []

        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        #
        # This is required because, if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        if num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list
        ) == 0 or disable_all_speculation:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)

        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop to perform speculative decoding
        in parallel worker."""
        while self._run_non_driver_rank():
            pass

    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        # When the batch size is too large, disable speculative decoding
        # to stop trading off throughput for latency.
        disable_all_speculation = (execute_model_req.running_queue_size >=
                                   self.disable_by_batch_size)

        return disable_all_speculation

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        if not disable_all_speculation:
            return

        for seq_group_metadata in seq_group_metadata_list:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO: We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within the spec decode worker.
            seq_group_metadata.num_speculative_tokens = 0

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to the proposer and scorer model so that the KV cache is consistent
        between the two. When skip_proposer is True, the proposer model is
        not called, meaning that the proposer's KV cache for these requests is
        not updated, so they cannot enable spec decode in later decode steps.
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        sampler_output.logprobs = None
        return [sampler_output]

    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).

        Returns True iff there are remaining sequences to process.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # Even if num_lookahead_slots is zero, we want to run the proposer
        # model, as it may have KV cache to update.
        #
        # We run the proposer once per lookahead slot. In the future we should
        # delegate how many times it runs to the proposer.
        for _ in range(max(num_lookahead_slots, 1)):
            self.proposer_worker.execute_model()

        self.scorer_worker.execute_model()
        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Generate proposals using the draft worker.
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

        proposal_scores = self.scorer.score_proposals(
            execute_model_req,
            proposals,
        )

        accepted_token_ids, target_logprobs = self._verify_tokens(
            execute_model_req.seq_group_metadata_list, proposal_scores,
            proposals, execute_model_req.num_lookahead_slots)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # Aphrodite currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of the target model, excluding the bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from the target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from the target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to the proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.rejection_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        return accepted_token_ids, logprobs
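
    # Illustrative note (editor's sketch, not part of the original code; the
    # concrete numbers are hypothetical): with k = 2 and
    # proposals.proposal_lens = [2, 0, 2], the split above yields
    # spec_indices = [0, 2], non_spec_indices = [1], and
    # original_indices = [0, 2, 1]. The rejection sampler returns a
    # [2, k + 1] tensor for the two speculative sequences; the single
    # non-speculative sequence contributes its one sampled token padded to
    # width k + 1 with -1. After torch.cat the rows are ordered [0, 2, 1],
    # and the final scatter-by-index assignment restores the original batch
    # order [0, 1, 2].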

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list = []
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        return self.scorer_worker.rank

    @property
    def device(self):
        return self.scorer_worker.device

    @property
    def _driver_rank(self) -> int:
        return 0

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError


def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of the number of KV entries per layer, the number of
    heads, and the hidden dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks such that, if allocated by both
    models, the total KV cache memory usage is no larger than the memory used
    by the blocks allocatable to the target model alone.
    """
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks
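

if __name__ == "__main__":
    # Editor's illustrative sketch (not part of the original module): a quick
    # check of how split_num_cache_blocks_evenly divides a GPU block budget.
    # The block sizes and block count below are hypothetical.
    scorer_block_bytes = 16 * 1024 * 1024  # e.g. a large target model
    proposer_block_bytes = 4 * 1024 * 1024  # e.g. a small draft model
    target_only_blocks = 1000  # blocks the target model could use alone

    shared_blocks = split_num_cache_blocks_evenly(scorer_block_bytes,
                                                  proposer_block_bytes,
                                                  target_only_blocks)

    # Each model gets 800 blocks: 800 * (16 MiB + 4 MiB) == 1000 * 16 MiB, so
    # the combined KV cache fits in the memory profiled for the target alone.
    assert shared_blocks == 800
    print(f"Each model gets {shared_blocks} KV cache blocks.")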