1
0

multi_step_worker.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import copy
  2. import weakref
  3. from typing import Dict, List, Set, Tuple
  4. import torch
  5. from aphrodite.common.sequence import (ExecuteModelRequest, HiddenStates,
  6. SamplerOutput, SequenceData,
  7. SequenceGroupMetadata)
  8. from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
  9. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  10. SpeculativeProposer)
  11. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  12. from aphrodite.spec_decode.top1_proposer import Top1Proposer
  13. from aphrodite.task_handler.worker import Worker
  class MultiStepWorker(Worker, ProposerWorkerBase):
      """The MultiStepWorker is equivalent to a Worker except that it allows
      multiple forward passes in a single call, assuming the scheduler has
      allocated enough space to store the additional KV. This reduces overhead
      by invoking the scheduler less.

      The MultiStepWorker does not support cache swap operations, or beam
      search. Cache swap operations do not require large modifications. On the
      other hand, beam search requires memory allocations during sequence forks
      and thus requires more thought for MultiStepWorker support.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

          # Declared here, but only assigned in init_device() once the device
          # and vocab size are available.
          self._proposer: SpeculativeProposer

      def init_device(self) -> None:
          """Initialize the device via the base Worker, then construct the
          Top1Proposer used by get_spec_proposals().
          """
          super().init_device()

          # Pass a weak proxy so the proposer (owned by this worker via
          # self._proposer) does not hold a strong reference back to it.
          self._proposer = Top1Proposer(
              weakref.proxy(self),  # type: ignore[arg-type]
              self.device,
              self.vocab_size,
              max_proposal_len=self.max_model_len,
          )

      def set_include_gpu_probs_tensor(self) -> None:
          """Make the model's sampler include the GPU probs tensor in its
          output."""
          # Need include_gpu_probs_tensor for MultiStepWorker
          self.model_runner.model.sampler.include_gpu_probs_tensor = True

      def set_should_modify_greedy_probs_inplace(self) -> None:
          """Set the sampler's should_modify_greedy_probs_inplace flag."""
          self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
              True)

      @torch.inference_mode()
      def sampler_output(
          self,
          execute_model_req: ExecuteModelRequest,
          sample_len: int,
          seq_ids_with_bonus_token_in_last_step: Set[int],
      ) -> Tuple[List[SamplerOutput], bool]:
          """Run the model forward pass sample_len times.

          Args:
              execute_model_req: The batch to produce proposals for.
              sample_len: Number of forward passes to run.
              seq_ids_with_bonus_token_in_last_step: IDs of sequences that
                  received a bonus token in the previous step. The batch is
                  expanded for these before running the model, and the
                  outputs are filtered back down afterwards.

          Returns:
              The list of sampler outputs, one per model forward pass, along
              with an indicator of whether the torch tensors in the sampler
              output need to be transposed in the later
              sampler_output_to_torch logic. For the multi step worker this
              indicator is always True.

          Raises:
              NotImplementedError: If the request contains cache operations
                  or a sequence group with more than one sequence.
          """
          self._raise_if_unsupported(execute_model_req)
          # Expand the batch for sequences with a bonus token.
          # Perform a forward pass on the expanded batch and filter the
          # response to retain only the original sequences' responses.
          expanded_request, indices_of_seq_with_bonus_tokens = \
              self._expand_execute_model_request(
                  execute_model_req, seq_ids_with_bonus_token_in_last_step)

          # Run model sample_len times.
          model_outputs: List[SamplerOutput] = []
          if isinstance(
                  self.model_runner, TP1DraftModelRunner
          ) and self.model_runner.supports_gpu_multi_step(expanded_request):
              # Here we run the draft_model_runner with multi-step prepare
              # on the GPU directly.
              expanded_request.num_steps = sample_len
              model_outputs = self.execute_model(
                  execute_model_req=expanded_request)
          else:
              # Here we run multi-step directly, with every step prepared
              # on the CPU.
              # TODO: Remove this branch once DraftModelRunner supports TP>1
              # and other restrictions that are part of DraftModelRunner's
              # supports_gpu_multi_step(..)
              for _ in range(sample_len):
                  step_outputs: List[SamplerOutput] = super().execute_model(
                      execute_model_req=expanded_request)
                  assert len(step_outputs) == 1, (
                      "composing multistep workers not supported")
                  step_output = step_outputs[0]

                  # Append the sampled tokens to the copied sequence data so
                  # the next forward pass continues from them.
                  self._append_new_tokens(
                      step_output, expanded_request.seq_group_metadata_list)
                  model_outputs.append(step_output)

          filtered_model_outputs = self._filter_model_output(
              model_outputs, indices_of_seq_with_bonus_tokens)
          return filtered_model_outputs, True

      @staticmethod
      def _expand_execute_model_request(
          execute_model_req: ExecuteModelRequest,
          seq_with_bonus_token_in_last_step: Set[int],
      ) -> Tuple[ExecuteModelRequest, List[int]]:
          """Expand the execute model request based on sequences with bonus
          tokens.

          For each sequence group containing a sequence with a bonus token,
          this method inserts a copy of the group with the bonus token
          removed, immediately before a shallow copy of the original group.
          Groups without bonus tokens are only shallow-copied. The indices of
          the original sequence groups in the expanded list are returned for
          later filtering.

          Args:
              execute_model_req (ExecuteModelRequest): The original execute
                  model request.
              seq_with_bonus_token_in_last_step (Set[int]): Set of sequence IDs
                  that contain bonus tokens.

          Returns:
              Tuple[ExecuteModelRequest, List[int]]: The updated execute model
              request with expanded sequences and a list of indices
              corresponding to the original sequence groups.
          """
          updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
          updated_execute_model_req = execute_model_req.clone(
              updated_seq_group_metadata_list)
          indices_of_original_sequence_groups: List[int] = []
          for seq_group in execute_model_req.seq_group_metadata_list:
              # Identify whether any sequence in the group got a bonus token.
              seq_group_has_bonus_tokens = any(
                  seq_id in seq_with_bonus_token_in_last_step
                  for seq_id in seq_group.seq_data)
              if seq_group_has_bonus_tokens:
                  # Create new sequences without the last bonus token. These
                  # new sequences have the same sequence id as the original
                  # sequence. We create a new sequence group and add them
                  # there.
                  updated_seq_group_metadata_list.append(
                      MultiStepWorker._copy_seq_metadata_excluding_last_token(
                          seq_group, seq_with_bonus_token_in_last_step))
              # Add the original sequence group.
              updated_seq_group_metadata_list.append(
                  MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
              # Record the index of the original sequence group.
              indices_of_original_sequence_groups.append(
                  len(updated_seq_group_metadata_list) - 1)

          updated_execute_model_req.seq_group_metadata_list = \
              updated_seq_group_metadata_list
          if isinstance(updated_execute_model_req.previous_hidden_states,
                        HiddenStates):
              updated_execute_model_req.previous_hidden_states \
                  .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)

          return updated_execute_model_req, indices_of_original_sequence_groups

      @staticmethod
      def _filter_model_output(
              expanded_batch_outputs: List[SamplerOutput],
              output_indices_to_retain: List[int]) -> List[SamplerOutput]:
          """Filter the model output to include only the specified sequence
          outputs.

          This method contracts the expanded batch output from the model to
          retain the outputs of only those sequences indicated by the provided
          indices. Only outputs, sampled_token_probs, logprobs and
          sampled_token_ids are carried over into the new SamplerOutputs.

          Args:
              expanded_batch_outputs (List[SamplerOutput]): The expanded output
                  batches from the model, one per forward pass.
              output_indices_to_retain (List[int]): Indices of the model
                  outputs to retain.

          Returns:
              List[SamplerOutput]: A list containing the filtered model
              outputs for the specified indices.
          """
          return [
              SamplerOutput(
                  outputs=[
                      expanded_batch_output.outputs[i]
                      for i in output_indices_to_retain
                  ] if len(expanded_batch_output.outputs) > 0 else [],
                  sampled_token_probs=(
                      expanded_batch_output.
                      sampled_token_probs[output_indices_to_retain]
                      if expanded_batch_output.sampled_token_probs is not None
                      else None),
                  logprobs=(
                      expanded_batch_output.logprobs[output_indices_to_retain]
                      if expanded_batch_output.logprobs is not None else None),
                  sampled_token_ids=(
                      expanded_batch_output.
                      sampled_token_ids[output_indices_to_retain]
                      if expanded_batch_output.sampled_token_ids is not None
                      else None))
              for expanded_batch_output in expanded_batch_outputs
          ]

      def get_spec_proposals(
          self,
          execute_model_req: ExecuteModelRequest,
          seq_ids_with_bonus_token_in_last_step: Set[int],
      ) -> SpeculativeProposals:
          """Produce speculations given an input batch of sequences. The number
          of speculative tokens per sequence is determined by max_proposal_len.

          Delegates to the Top1Proposer created in init_device().
          """
          return self._proposer.get_spec_proposals(
              execute_model_req, seq_ids_with_bonus_token_in_last_step)

      @staticmethod
      def _append_new_tokens(
              model_output: SamplerOutput,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
          """Given model output from a single run, append the tokens to the
          sequences. This is normally done outside of the worker, but it is
          required if the worker is to perform multiple forward passes.

          Args:
              model_output: The SamplerOutput of one forward pass; iterated in
                  lockstep with seq_group_metadata_list.
              seq_group_metadata_list: Sequence groups to update in place
                  (is_prompt is cleared and sampled tokens are appended).
          """
          for seq_group_metadata, sequence_group_outputs in zip(
                  seq_group_metadata_list, model_output):
              seq_group_metadata.is_prompt = False

              for seq_output in sequence_group_outputs.samples:
                  # NOTE: Beam search is not supported, so we can assume that
                  # parent_seq_id == seq_id.
                  seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                  token_id = seq_output.output_token
                  token_logprob = seq_output.logprobs[token_id]

                  seq.append_token_id(token_id, token_logprob.logprob)
                  seq.update_num_computed_tokens(1)

      @staticmethod
      def _shallow_copy_seq_group_metadata(
          seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
          """Copy input data structures to remove side-effects when input data
          structures are shared with other modules.

          Helpful when the Aphrodite scheduler runs in the same process as
          the worker. The alternative is deep-copying (or other form of deep
          copy); this has performance downsides.
          """
          # Shallow-copy the SequenceGroupMetadata. This allows us to
          # append tokens and change is_prompt without external side-effects.
          # We must shallow-copy seq_group_metadata as is_prompt could change.
          new_seq_group_metadata = copy.copy(seq_group_metadata)

          # We must shallow-copy seq_data as we will append token ids.
          new_seq_data: Dict[int, SequenceData] = {}
          for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
              new_seq_data[seq_id] = copy.copy(old_seq_data)
              new_seq_data[seq_id].output_token_ids = \
                  old_seq_data.output_token_ids[:]

          new_seq_group_metadata.seq_data = new_seq_data
          return new_seq_group_metadata

      @staticmethod
      def _copy_seq_metadata_excluding_last_token(
          seq_group_metadata: SequenceGroupMetadata,
          seq_ids_to_copy: Set[int],
      ) -> SequenceGroupMetadata:
          """Create a shallow copy of the given SequenceGroupMetadata, retaining
          only the sequence IDs specified in seq_ids_to_copy.

          For each of these sequence IDs, all output_token_ids except the last
          one are copied. Sequence IDs not in seq_ids_to_copy are excluded
          from the copy.

          Parameters:
              seq_group_metadata (SequenceGroupMetadata): The original sequence
                  group metadata.
              seq_ids_to_copy (Set[int]): The set of sequence IDs to include in
                  the copy.

          Returns:
              SequenceGroupMetadata: A shallow copy of the sequence group
              metadata with the specified modifications.
          """
          # Shallow-copy the SequenceGroupMetadata.
          new_seq_group_metadata = copy.copy(seq_group_metadata)
          # Shallow-copy seq_data and modify the output_token_ids.
          new_seq_data: Dict[int, SequenceData] = {}
          for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
              if seq_id not in seq_ids_to_copy:
                  continue
              new_seq_data[seq_id] = copy.copy(old_seq_data)
              # Copy all the output token ids except the last.
              # Also reduce num_computed_tokens by 1 since we are not
              # including the last output token.
              # NOTE: num_computed_tokens is not directly used by the
              # speculative decoding workers, as it is only relevant for
              # chunked prefill, which is disabled for speculative decoding.
              # However, to maintain consistency in num_computed_tokens,
              # we update it here.
              new_seq_data[seq_id].output_token_ids = \
                  old_seq_data.output_token_ids[:-1]
              new_seq_data[seq_id].update_num_computed_tokens(-1)
          new_seq_group_metadata.seq_data = new_seq_data
          return new_seq_group_metadata

      def _assert_enough_kv_space(
              self, seq_group_metadata_list: List[SequenceGroupMetadata],
              num_steps: int) -> None:
          """Assert there are enough physical blocks per sequence to store the
          current KV plus additional KV from num_steps tokens.

          Raises:
              ValueError: If any sequence's allocated KV slots are fewer than
                  required after num_steps additional tokens.
          """
          # NOTE(review): no caller of this method is visible in this class;
          # confirm whether it is still used elsewhere.
          assert self.model_runner.block_size is not None
          for seq_group_metadata in seq_group_metadata_list:
              # Only one seq_id is guaranteed because there is no beam search.
              seq_id = next(iter(seq_group_metadata.seq_data))
              seq = seq_group_metadata.seq_data[seq_id]

              # After num_steps, the seq len will be the current seq len
              # plus one token per step.
              final_seq_len = seq.get_len() + num_steps

              # We will have final_seq_len - 1 KV because Aphrodite saves KV
              # for a token in the iteration after the token was generated.
              required_num_kv_slots = final_seq_len - 1

              # The allocated number of kv slots is the number of allocated
              # blocks times the number of slots of block.
              number_physical_blocks = len(
                  seq_group_metadata.block_tables[seq_id])
              allocated_kv_slots = (number_physical_blocks *
                                    self.model_runner.block_size)

              if required_num_kv_slots > allocated_kv_slots:
                  request_id = seq_group_metadata.request_id
                  raise ValueError(
                      "The worker attempted to run "
                      f"{num_steps} times but found insufficient KV space for "
                      f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                      f"{required_num_kv_slots=}).")

      def _raise_if_unsupported(
          self,
          execute_model_req: ExecuteModelRequest,
      ) -> None:
          """MultiStepWorker does not yet implement support for cache swap
          operations or beam search.

          Raises:
              NotImplementedError: If any of blocks_to_swap_in,
                  blocks_to_swap_out or blocks_to_copy is truthy, or if any
                  sequence group does not contain exactly one sequence.
          """
          if any((
                  execute_model_req.blocks_to_swap_in,
                  execute_model_req.blocks_to_swap_out,
                  execute_model_req.blocks_to_copy,
          )):
              raise NotImplementedError(
                  "MultiStepWorker does not support cache operations")

          if any(
                  len(seq_group_metadata.seq_data) != 1
                  for seq_group_metadata in
                  execute_model_req.seq_group_metadata_list):
              raise NotImplementedError(
                  "MultiStepWorker does not support beam search.")