medusa_worker.py

import weakref
from typing import List, Optional, Set, Tuple

import torch

from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                       SequenceGroupMetadata)
from aphrodite.modeling import SamplingMetadata
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.task_handler.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
    """Worker for Medusa."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Lazy initialization.
        self._proposer: Top1Proposer

    def init_device(self):
        super().init_device()

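        # Top1Proposer turns this worker's sampler output into speculative
        # proposals, with proposal length capped at the model's max length.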
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
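        # Intentionally a no-op for the Medusa proposer.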
        pass

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler outputs, one per layer, along with an
        indicator of whether the torch tensors in the sampler output need to
        be transposed by the later sampler_output_to_torch logic.

        For the Medusa worker, this indicator is always False.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)
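
        # Build sampling metadata from the per-sequence lengths computed
        # above.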
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory)
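
        # The Medusa heads generate proposals directly from the hidden states
        # of the previous target-model forward pass.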
        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            sampling_metadata=sampling_metadata)

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
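                # Prompts may be chunked: cap the sequence length at the
                # scheduled chunk and count only the new tokens as the query.
                # Decode steps always query a single token.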
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")
        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")