medusa_worker.py

import weakref
from typing import List, Optional, Set, Tuple

import torch

from aphrodite.common.sequence import (ExecuteModelRequest,
                                       SequenceGroupMetadata)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.worker.worker import Worker
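
# Background: Medusa attaches several lightweight decoding heads to the target
# model; each head predicts a token at a different future offset from the same
# hidden state, so draft tokens can be proposed without another forward pass
# through the base model. This worker only proposes tokens; scoring and
# verification of the proposals are handled by the surrounding speculative
# decoding machinery.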


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
    """Worker for Medusa."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Lazily initialized by init_device().
        self._proposer: Top1Proposer

    def init_device(self):
        super().init_device()
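
        # Top1Proposer wraps this worker: it decides which sequences in the
        # batch can be speculated on (subject to max_proposal_len), calls
        # sampler_output() for those, and packages the results into
        # SpeculativeProposals.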
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    # These sampler-configuration hooks are no-ops for the Medusa proposer.
    def set_include_gpu_probs_tensor(self):
        pass

    def set_should_modify_greedy_probs_inplace(self):
        pass

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns a list of sampler outputs, one per layer, along with an
        indicator of whether the torch tensors in the sampler output need to
        be transposed by the later sampler_output_to_torch logic. For the
        Medusa worker, this indicator is always False.
        """
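        # Flow: validate the request, derive per-sequence lengths, build the
        # SamplingMetadata, then ask the Medusa heads for proposals.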
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)
        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)
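
        # The Medusa heads propose tokens directly from the hidden states
        # captured on the previous (verified) forward pass of the target
        # model, so no additional pass through the base model is run here.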
        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            sampling_metadata=sampling_metadata)

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
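        """Compute per-sequence lengths and query lengths for the batch.

        Worked example with illustrative numbers (not from the source): a
        chunked prompt with 100 total tokens, 64 of them already computed,
        and a token_chunk_size of 16 yields seq_len = 80 and query_len = 16;
        a decode sequence of length 37 yields seq_len = 37 and query_len = 1.
        """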
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")