1
0

spec_decode_worker.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. from functools import cached_property
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import torch
  4. from loguru import logger
  5. from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
  6. SequenceGroupMetadata)
  7. from aphrodite.distributed.communication_op import broadcast_tensor_dict
  8. from aphrodite.modeling.layers.rejection import RejectionSampler
  9. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  10. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  11. SpeculativeScorer,
  12. SpeculativeScores)
  13. from aphrodite.spec_decode.metrics import AsyncMetricsCollector
  14. from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
  15. from aphrodite.spec_decode.ngram_worker import NGramWorker
  16. from aphrodite.spec_decode.util import (create_sequence_group_output,
  17. get_all_num_logprobs, get_all_seq_ids,
  18. get_sampled_token_logprobs, nvtx_range,
  19. split_batch_by_proposal_len)
  20. from aphrodite.task_handler.worker import Worker
  21. from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
  22. WorkerBase)
  23. def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
  24. """Helper method that is the entrypoint for Executors which use
  25. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
  26. """
  27. assert "speculative_config" in kwargs
  28. speculative_config = kwargs.get("speculative_config")
  29. assert speculative_config is not None
  30. target_worker = Worker(*args, **kwargs)
  31. draft_worker_kwargs = kwargs.copy()
  32. # Override draft-model specific worker args.
  33. draft_worker_kwargs.update(
  34. model_config=speculative_config.draft_model_config,
  35. parallel_config=speculative_config.draft_parallel_config,
  36. ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
  37. ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
  38. # TODO allow draft-model specific load config.
  39. #load_config=load_config,
  40. )
  41. spec_decode_worker = SpecDecodeWorker.create_worker(
  42. scorer_worker=target_worker,
  43. draft_worker_kwargs=draft_worker_kwargs,
  44. disable_by_batch_size=speculative_config.
  45. speculative_disable_by_batch_size,
  46. )
  47. return spec_decode_worker
  48. class SpecDecodeWorker(LoraNotSupportedWorkerBase):
  49. """Worker which implements speculative decoding.
  50. Speculative decoding reduces decoding per-token latency by using a proposal
  51. method, such as a small draft model, to speculate ahead of a larger LLM. The
  52. probabilities of the speculative tokens are then determined by the larger
  53. LLM, after which some verification routine determines which (if any) of the
  54. speculative tokens are accepted by the larger LLM.
  55. The current implementation has the following limitations:
  56. * Only draft-model proposal is implemented (contributions for more forms are
  57. welcome!).
  58. * Only top-1 proposal and scoring are implemented. Tree-attention is left as
  59. future work.
  60. * Only lossless rejection sampling is supported. Contributions adding lossy
  61. verification routines are welcome (e.g. Medusa's typical acceptance).
  62. * All sequences in a batch must have the same proposal length, or zero. This
  63. can be improved by having per-sequence speculation in the future.
  64. * The scoring forward pass is done without an MQA kernel, which is
  65. suboptimal especially as the batch size, proposal length, and sequence
  66. lengths grow. Contributions to add a MQA scoring are welcome once
  67. correctness tests pass.
  68. """
  69. @classmethod
  70. def create_worker(
  71. cls,
  72. scorer_worker: WorkerBase,
  73. draft_worker_kwargs: Dict[str, Any],
  74. disable_by_batch_size: Optional[int],
  75. ) -> "SpecDecodeWorker":
  76. ngram_prompt_lookup_max = (
  77. draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
  78. ngram_prompt_lookup_min = (
  79. draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
  80. disable_bonus_tokens = True
  81. if ngram_prompt_lookup_max > 0:
  82. disable_bonus_tokens = False
  83. proposer_worker = NGramWorker(**draft_worker_kwargs)
  84. proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
  85. ngram_prompt_lookup_max)
  86. else:
  87. proposer_worker = MultiStepWorker(**draft_worker_kwargs)
  88. logger.info("Configuring SpecDecodeWorker with "
  89. f"proposer={type(proposer_worker)}")
  90. return SpecDecodeWorker(
  91. proposer_worker,
  92. scorer_worker,
  93. disable_by_batch_size=disable_by_batch_size,
  94. rejection_sampler=RejectionSampler(
  95. disable_bonus_tokens=disable_bonus_tokens, ))
  96. def __init__(
  97. self,
  98. proposer_worker: WorkerBase,
  99. scorer_worker: WorkerBase,
  100. rejection_sampler: RejectionSampler,
  101. metrics_collector: Optional[AsyncMetricsCollector] = None,
  102. disable_by_batch_size: Optional[int] = None,
  103. ):
  104. """
  105. Create a SpecDecodeWorker.
  106. Args:
  107. proposer_worker: A worker that can produce speculative tokens for
  108. sequences.
  109. scorer_worker: A worker that produces probabilities of speculative
  110. tokens according to some base model. Typically a vanilla
  111. Aphrodite Worker.
  112. rejection_sampler: A Torch module used to perform modified rejection
  113. sampling for speculative decoding.
  114. disable_by_batch_size: If the batch size is larger than this,
  115. disable speculative decoding for new incoming requests.
  116. metrics_collector: Helper class for collecting metrics; can be set
  117. for testing purposes.
  118. """
  119. self.proposer_worker = proposer_worker
  120. self.scorer_worker = scorer_worker
  121. self.disable_by_batch_size = disable_by_batch_size or float("inf")
  122. self.rejection_sampler = rejection_sampler
  123. self._metrics = AsyncMetricsCollector(
  124. rejection_sampler
  125. ) if metrics_collector is None else metrics_collector
  126. self.probs_dtype = self.rejection_sampler.probs_dtype
  127. self.token_id_dtype = self.rejection_sampler.token_id_dtype
  128. # Lazy initiazliation.
  129. self.scorer: SpeculativeScorer
  130. def init_device(self) -> None:
  131. """Initialize both scorer and proposer models.
  132. """
  133. # The scorer worker model is initialized first in case the proposer
  134. # model has a smaller TP degree than the target worker.
  135. self.scorer_worker.init_device()
  136. self.proposer_worker.init_device()
  137. # NOTE(cade): load_model is not part of the WorkerBase interface.
  138. self.scorer_worker.load_model()
  139. self.proposer_worker.load_model()
  140. self._metrics.init_gpu_tensors(self.rank)
  141. self.rejection_sampler.init_gpu_tensors(self.rank)
  142. self.scorer = BatchExpansionTop1Scorer(
  143. scorer_worker=self.scorer_worker,
  144. device=self.device,
  145. vocab_size=self._vocab_size)
  146. self._configure_model_sampler_for_spec_decode()
  147. def load_model(self, *args, **kwargs):
  148. pass
  149. def _configure_model_sampler_for_spec_decode(self):
  150. """Configure model sampler to emit GPU tensors. This allows spec decode
  151. to keep data on device without transferring to CPU and serializing,
  152. which significantly reduces overhead of rejection sampling.
  153. NOTE: This breaks abstraction boundaries pretty badly. The better
  154. design is to have the "move to CPU and serialize" sampling decision be
  155. done outside of the model/sampler; this way the "last-mile" worker
  156. object which interfaces with the scheduler can serialize and incur the
  157. performance hit as necessary. This allows us to run the worker several
  158. iterations in a row without incurring the "move to CPU and serialize"
  159. performance penalty.
  160. Since this requires a large change to Aphrodite, we defer it to later
  161. and temporarily accept this broken abstraction boundary.
  162. NOTE: This will require a special check if the proposer worker
  163. does not have a sampler (e.g. ngram speculation).
  164. """
  165. (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
  166. ) = True
  167. self.proposer_worker.set_include_gpu_probs_tensor()
  168. def determine_num_available_blocks(self) -> Tuple[int, int]:
  169. """Determine the number of cache blocks to use.
  170. This is done by profiling the scorer model (which is typically the
  171. larger of the two). Then the total memory which would be used by the
  172. scorer cache is divided evenly between the proposer and scorer model KV,
  173. such that the number of blocks is equal in both KV caches.
  174. """
  175. num_gpu_blocks, num_cpu_blocks = (
  176. self.scorer_worker.determine_num_available_blocks())
  177. scorer_cache_block_size_bytes = (
  178. self.scorer_worker.get_cache_block_size_bytes())
  179. proposer_cache_block_size_bytes = (
  180. self.proposer_worker.get_cache_block_size_bytes())
  181. new_num_gpu_blocks = split_num_cache_blocks_evenly(
  182. scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
  183. num_gpu_blocks)
  184. return new_num_gpu_blocks, num_cpu_blocks
  185. def initialize_cache(self, num_gpu_blocks: int,
  186. num_cpu_blocks: int) -> None:
  187. """Initialize the cache engine of the scorer and proposer workers.
  188. """
  189. self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  190. num_cpu_blocks=num_cpu_blocks)
  191. self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
  192. num_cpu_blocks=num_cpu_blocks)
  193. @torch.inference_mode()
  194. def execute_model(
  195. self,
  196. execute_model_req: Optional[ExecuteModelRequest] = None
  197. ) -> List[SamplerOutput]:
  198. """Perform speculative decoding on the input batch.
  199. """
  200. if self.rank != self._driver_rank:
  201. self._run_non_driver_rank()
  202. return []
  203. if execute_model_req is None:
  204. # This signals that there's no more requests to process for now.
  205. # All workers are running infinite loop with broadcast_tensor_dict,
  206. # and it stops the loop when the driver broadcasts an empty input.
  207. # Send an empty input to notify all other workers to stop their
  208. # execution loop.
  209. broadcast_tensor_dict({}, src=0)
  210. return []
  211. disable_all_speculation = self._should_disable_all_speculation(
  212. execute_model_req)
  213. num_lookahead_slots = execute_model_req.num_lookahead_slots
  214. # Broadcast how many lookahead slots are scheduled for this step, and
  215. # whether all speculation is disabled, to all non-driver workers.
  216. # This is required as if the number of draft model runs changes
  217. # dynamically, the non-driver workers won't know unless we perform a
  218. # communication to inform then.
  219. broadcast_dict = dict(
  220. num_lookahead_slots=num_lookahead_slots,
  221. disable_all_speculation=disable_all_speculation,
  222. )
  223. broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
  224. assert execute_model_req.seq_group_metadata_list is not None, (
  225. "speculative decoding requires non-None seq_group_metadata_list")
  226. self._maybe_disable_speculative_tokens(
  227. disable_all_speculation, execute_model_req.seq_group_metadata_list)
  228. # Speculative decoding is disabled in the following cases:
  229. # 1. Prefill phase: Speculative decoding is not
  230. # used during the prefill phase.
  231. # 2. Auto-disable enabled: The running queue size exceeds
  232. # the specified threshold.
  233. # 3. No request: There are no requests in the batch.
  234. # In any of these cases, the proposer and scorer workers
  235. # are called normally.
  236. if num_lookahead_slots == 0 or len(
  237. execute_model_req.seq_group_metadata_list
  238. ) == 0 or disable_all_speculation:
  239. return self._run_no_spec(execute_model_req,
  240. skip_proposer=disable_all_speculation)
  241. return self._run_speculative_decoding_step(execute_model_req,
  242. num_lookahead_slots)
  243. @torch.inference_mode()
  244. def start_worker_execution_loop(self) -> None:
  245. """Execute model loop to perform speculative decoding
  246. in parallel worker."""
  247. while self._run_non_driver_rank():
  248. pass
  249. def _should_disable_all_speculation(
  250. self, execute_model_req: ExecuteModelRequest) -> bool:
  251. # When the batch size is too large, disable speculative decoding
  252. # to stop trading off throughput for latency.
  253. disable_all_speculation = (execute_model_req.running_queue_size >=
  254. self.disable_by_batch_size)
  255. return disable_all_speculation
  256. def _maybe_disable_speculative_tokens(
  257. self, disable_all_speculation: bool,
  258. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  259. if not disable_all_speculation:
  260. return
  261. for seq_group_metadata in seq_group_metadata_list:
  262. # Once num_speculative_tokens is set to 0, the spec decode
  263. # of this request will be disabled forever.
  264. # TODO: We currently store spec decoding specific
  265. # state in the global data structure, but we should maintain
  266. # this state within spec decode worker.
  267. seq_group_metadata.num_speculative_tokens = 0
  268. @nvtx_range("spec_decode_worker._run_no_spec")
  269. def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
  270. skip_proposer: bool) -> List[SamplerOutput]:
  271. """Run a single generation step without any speculation. The input is
  272. sent to the proposer and scorer model so that the KV cache is consistent
  273. between the two. When skip_proposer is True, the proposer model is
  274. not called, meaning that the kv-cache in proposer for requests is not
  275. updated, so they cannot enable spec decode in the rest decoding.
  276. """
  277. if not skip_proposer:
  278. self.proposer_worker.execute_model(execute_model_req)
  279. sampler_output = self.scorer_worker.execute_model(execute_model_req)
  280. assert len(sampler_output) == 1
  281. sampler_output = sampler_output[0]
  282. # Clear device tensors from sampler output. This reduces communication
  283. # overhead when the engine runs in a different process than the workers.
  284. sampler_output.probs = None
  285. sampler_output.sampled_tokens = None
  286. sampler_output.logprobs = None
  287. return [sampler_output]
  288. def _run_non_driver_rank(self) -> bool:
  289. """Run proposer and verifier model in non-driver workers. This is used
  290. for both speculation cases (num_lookahead_slots>0) and non-speculation
  291. cases (e.g. prefill).
  292. Returns True iff there are remaining sequences to process.
  293. """
  294. assert self.rank != self._driver_rank
  295. data = broadcast_tensor_dict(src=self._driver_rank)
  296. if not data:
  297. return False
  298. num_lookahead_slots = data["num_lookahead_slots"]
  299. # Even if num_lookahead_slots is zero, we want to run the proposer model
  300. # as it may have KV.
  301. #
  302. # We run the proposer once per lookahead slot. In the future we should
  303. # delegate how many times it runs to the proposer.
  304. for _ in range(max(num_lookahead_slots, 1)):
  305. self.proposer_worker.execute_model()
  306. self.scorer_worker.execute_model()
  307. return True
  308. @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
  309. def _run_speculative_decoding_step(
  310. self, execute_model_req: ExecuteModelRequest,
  311. num_lookahead_slots: int) -> List[SamplerOutput]:
  312. """Execute a single step of speculative decoding.
  313. This invokes the proposer worker to get k speculative tokens for each
  314. sequence, then scores each speculative token using the scoring worker.
  315. Returns a list of SamplerOutput, each containing a single token per
  316. sequence.
  317. """
  318. assert num_lookahead_slots == execute_model_req.num_lookahead_slots
  319. # Generate proposals using draft worker.
  320. proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
  321. proposal_scores = self.scorer.score_proposals(
  322. execute_model_req,
  323. proposals,
  324. )
  325. accepted_token_ids, target_logprobs = self._verify_tokens(
  326. execute_model_req.seq_group_metadata_list, proposal_scores,
  327. proposals, execute_model_req.num_lookahead_slots)
  328. return self._create_output_sampler_list(
  329. execute_model_req.seq_group_metadata_list,
  330. accepted_token_ids,
  331. target_logprobs=target_logprobs,
  332. k=execute_model_req.num_lookahead_slots)
  333. @nvtx_range("spec_decode_worker._verify_tokens")
  334. def _verify_tokens(
  335. self,
  336. seq_group_metadata_list: List[SequenceGroupMetadata],
  337. proposal_scores: SpeculativeScores,
  338. proposals: SpeculativeProposals,
  339. max_proposal_len: int,
  340. ) -> Tuple[torch.Tensor, torch.Tensor]:
  341. """Determine which speculative tokens are accepted using the
  342. probabilities of each token according to the proposer and scorer models.
  343. Returns a tuple of Tensors, one for the accepted token ids and one for
  344. the logprobs according to the scoring model.
  345. """
  346. proposal_lens_list = proposals.proposal_lens.tolist()
  347. # Aphrodite currently only supports proposal lens equal to zero or the
  348. # batch proposal len. This adds some complexity (splitting the batch
  349. # into spec and non spec sequences) and should be removed in the
  350. # future. It can be done by supporting per-sequence proposal lens.
  351. _, spec_indices = split_batch_by_proposal_len(
  352. seq_group_metadata_list,
  353. proposal_lens_list,
  354. select_proposal_len_zero=False)
  355. _, non_spec_indices = split_batch_by_proposal_len(
  356. seq_group_metadata_list,
  357. proposal_lens_list,
  358. select_proposal_len_zero=True)
  359. original_indices = spec_indices + non_spec_indices
  360. # Get probabilities of target model, excluding bonus token.
  361. proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
  362. # Get non-speculative sampled tokens from target model.
  363. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
  364. # Get bonus tokens from target model.
  365. bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
  366. # Get probabilities according to proposal method.
  367. proposal_probs = proposals.proposal_probs[spec_indices]
  368. # Get proposed tokens.
  369. proposal_token_ids = proposals.proposal_token_ids[spec_indices]
  370. accepted_token_ids = self.rejection_sampler(
  371. target_probs=proposal_verifier_probs,
  372. bonus_token_ids=bonus_token_ids,
  373. draft_probs=proposal_probs,
  374. draft_token_ids=proposal_token_ids,
  375. )
  376. # Append output tokens from non-speculative sequences to
  377. # the accepted token ids tensor.
  378. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
  379. 1).clone()
  380. non_spec_token_ids[:, 1:] = -1
  381. accepted_token_ids = torch.cat(
  382. [accepted_token_ids, non_spec_token_ids])
  383. logprobs = proposal_scores.logprobs
  384. # Rearrange so that results are in the order of the original seq group
  385. # metadata.
  386. accepted_token_ids[original_indices] = accepted_token_ids.clone()
  387. return accepted_token_ids, logprobs
  388. def _create_output_sampler_list(
  389. self,
  390. seq_group_metadata_list: List[SequenceGroupMetadata],
  391. accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
  392. target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
  393. k: int,
  394. ) -> List[SamplerOutput]:
  395. """Given the accepted token ids, create a list of SamplerOutput.
  396. The output is padded with -1 tokens such that each sequence has
  397. the same number of outputs.
  398. """
  399. batch_size, num_steps = accepted_token_ids.shape
  400. # Organize input tensors by step instead of by sequence.
  401. target_logprobs_by_step = target_logprobs.transpose(0, 1)
  402. accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
  403. # Get the logprobs/rank of the accepted tokens.
  404. (accepted_token_id_ranks_by_step,
  405. accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
  406. logprob_tensor=target_logprobs_by_step,
  407. sampled_token_ids=accepted_token_ids_by_step,
  408. )
  409. # Get the top-k logprobs (which may or may not include the logprob of
  410. # the accepted token).
  411. (topk_logprobs_by_step,
  412. topk_indices_by_step) = target_logprobs_by_step.topk(
  413. k=self.scorer_worker.model_config.max_logprobs,
  414. dim=-1,
  415. )
  416. # Get the sequence ids and num_logprobs (sampling parameter) in the
  417. # batch.
  418. seq_ids = get_all_seq_ids(seq_group_metadata_list)
  419. num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
  420. # Serialize all tensors to CPU Python lists.
  421. accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
  422. accepted_token_id_ranks_by_step = (
  423. accepted_token_id_ranks_by_step.tolist())
  424. accepted_token_id_logprobs_by_step = (
  425. accepted_token_id_logprobs_by_step.tolist())
  426. topk_logprobs_by_step = topk_logprobs_by_step.tolist()
  427. topk_indices_by_step = topk_indices_by_step.tolist()
  428. # Construct the output on a per-step, per-sequence basis.
  429. sampler_output_list = []
  430. for step_index in range(num_steps):
  431. if all(token_id == -1
  432. for token_id in accepted_token_ids_by_step[step_index]):
  433. break
  434. step_output_token_ids = []
  435. for sequence_index in range(batch_size):
  436. # Each sequence may have a different num_logprobs; retrieve it.
  437. num_logprobs = num_logprobs_per_seq[sequence_index]
  438. step_output_token_ids.append(
  439. create_sequence_group_output(
  440. token_id=accepted_token_ids_by_step[step_index]
  441. [sequence_index],
  442. token_id_logprob_rank=accepted_token_id_ranks_by_step[
  443. step_index][sequence_index],
  444. token_id_logprob=accepted_token_id_logprobs_by_step[
  445. step_index][sequence_index],
  446. seq_id=seq_ids[sequence_index],
  447. topk_token_ids=topk_indices_by_step[step_index]
  448. [sequence_index][:num_logprobs],
  449. topk_logprobs=topk_logprobs_by_step[step_index]
  450. [sequence_index][:num_logprobs],
  451. ))
  452. sampler_output_list.append(
  453. SamplerOutput(outputs=step_output_token_ids))
  454. maybe_rejsample_metrics = (
  455. self._metrics.maybe_collect_rejsample_metrics(k))
  456. if maybe_rejsample_metrics is not None:
  457. sampler_output_list[
  458. 0].spec_decode_worker_metrics = maybe_rejsample_metrics
  459. return sampler_output_list
  460. @cached_property
  461. def _vocab_size(self) -> int:
  462. """Get the vocab size of the model and make sure it's consistent between
  463. draft and target workers.
  464. """
  465. vocab_sizes = [
  466. worker.vocab_size
  467. for worker in [self.proposer_worker, self.scorer_worker]
  468. ]
  469. assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
  470. return vocab_sizes[0]
  471. @property
  472. def rank(self):
  473. return self.scorer_worker.rank
  474. @property
  475. def device(self):
  476. return self.scorer_worker.device
  477. @property
  478. def _driver_rank(self) -> int:
  479. return 0
  480. def get_cache_block_size_bytes(self):
  481. """Return the size of a cache block in bytes.
  482. This function is only used to compose workers within a SpecDecodeWorker.
  483. We leave composing a SpecDecodeWorker within a SpecDecodeWorker
  484. undefined for now, although it could be implemented in the future.
  485. See https://arxiv.org/abs/2308.04623.
  486. """
  487. raise NotImplementedError
  488. def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
  489. proposer_cache_block_size_bytes: int,
  490. total_num_gpu_blocks: int) -> int:
  491. """Given total_num_gpu_blocks, the number of GPU blocks that could be
  492. allocate to the target model, this function calculates how many blocks
  493. should be given to the draft and target model.
  494. Note that usually the block size, in bytes, of each model is different,
  495. as it's a function of number of KV/layer, number of heads, and hidden
  496. dimension size.
  497. Since the target and draft models allocate the same number of blocks, we
  498. simply calculate the number of blocks where if allocated by both models,
  499. the total memory usage from KV cache is no larger than the number of
  500. blocks allocatable by the target model alone.
  501. """
  502. new_num_gpu_blocks = int(
  503. total_num_gpu_blocks * scorer_cache_block_size_bytes /
  504. (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
  505. return new_num_gpu_blocks