multi_step_worker.py 8.6 KB

import copy
from typing import Dict, List, Tuple

import torch

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.task_handler.worker import Worker


class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations or beam
    search. Cache swap operations would not require large modifications;
    beam search, on the other hand, requires memory allocations during
    sequence forks and thus needs more thought before it can be supported.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Lazy initialization.
        self._proposer: Top1Proposer

    def init_device(self):
        super().init_device()

        self._proposer = Top1Proposer(
            self,
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
        # Need include_gpu_probs_tensor for multi_step_worker.
        self.model_runner.model.sampler.include_gpu_probs_tensor = True
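
    # NOTE (inferred, not stated in this file): keeping the sampled
    # probabilities as a GPU tensor presumably lets the speculative-decoding
    # scorer consume the draft model's probabilities directly, without a
    # host round-trip, when verifying proposed tokens.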

    @torch.inference_mode()
    def sampler_output(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times. Returns the list of
        sampler outputs, one per model forward pass.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for sample_len tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, sample_len)

        # Run the model sample_len times.
        model_outputs = []
        for _ in range(sample_len):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )
            assert (len(model_output) == 1
                    ), "composing multistep workers not supported"
            model_output = model_output[0]

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs, True
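
    # NOTE: Because _append_new_tokens mutates the copied metadata in place,
    # each iteration of the loop in sampler_output sees the token sampled by
    # the previous iteration as part of its input. This is what turns a
    # single call into sample_len autoregressive draft steps.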

    def get_spec_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            max_proposal_len,
        )

    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        Helpful when the Aphrodite scheduler runs in the same process as the
        worker. The alternative is deep-copying, which has performance
        downsides.
        """
        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could
            # change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids.
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list
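
    # Illustration (hedged; toy objects, not Aphrodite types): the copy above
    # is shallow, but each sequence gets its own output_token_ids list, so
    # appending draft tokens cannot leak back into the scheduler's view.
    # A bare copy.copy would still share the underlying list:
    #
    #     import copy
    #     original = {"output_token_ids": [1, 2]}
    #     aliased = copy.copy(original)
    #     aliased["output_token_ids"].append(3)    # also visible in original
    #     detached = copy.copy(original)
    #     detached["output_token_ids"] = original["output_token_ids"][:]
    #     detached["output_token_ids"].append(4)   # original is unaffected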

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because Aphrodite saves KV for
            # a token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of KV slots is the number of allocated
            # blocks times the number of slots per block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run the model "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")