medusa_worker.py

import weakref
from typing import List, Optional, Set, Tuple

import torch

from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                       SequenceGroupMetadata)
from aphrodite.modeling import SamplingMetadata
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.task_handler.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
    """Worker for Medusa."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Lazily initialized by init_device().
        self._proposer: Top1Proposer

    def init_device(self):
        super().init_device()
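        # The proposer wraps this worker through a weak proxy (avoiding a
        # reference cycle) and turns its sampler output into top-1 proposals,
        # capped by the model's maximum sequence length.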
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
        pass

    def set_should_modify_greedy_probs_inplace(self):
        pass

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler outputs, one per layer, along with an
        indicator of whether the torch tensors in the sampler output need to
        be transposed by the later sampler_output_to_torch logic.

        For MedusaWorker, this indicator is always False.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)
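        # The Medusa heads consume the hidden states from the target model's
        # previous forward pass and emit one SamplerOutput per head.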
        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            sampling_metadata=sampling_metadata)

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
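                # Prefill (possibly chunked): cap the sequence length at
                # context_len + token_chunk_size and query only the newly
                # scheduled tokens. Decode: query a single new token.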
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")
        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")