multi_step_worker.py

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
from aphrodite.task_handler.multi_step_model_runner import (
    MultiStepModelRunner, StatefulModelInput)
from aphrodite.task_handler.worker import Worker, WorkerInput


@dataclass
class MultiStepState:
    worker_input: WorkerInput
    model_input: StatefulModelInput


class MultiStepWorker(Worker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        base_model_runner = self.model_runner
        # For multi-step decoding, wrap the base model runner with a
        # MultiStepModelRunner.
        self.model_runner = MultiStepModelRunner(
            base_model_runner,
            base_model_runner.model_config,
            base_model_runner.parallel_config,
            base_model_runner.scheduler_config,
            base_model_runner.device_config,
            base_model_runner.cache_config,
            load_config=base_model_runner.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=base_model_runner.is_driver_worker,
            prompt_adapter_config=base_model_runner.prompt_adapter_config,
        )

        pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
        # Per-virtual-engine cache of the inputs prepared on the first step,
        # indexed by virtual_engine.
        self.multi_step_states: List[Optional[MultiStepState]] = [
            None
        ] * pipeline_parallel_size
        self.temp_output = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Get the driver input and broadcast it to other workers.
        """
        assert self.is_driver_worker
        virtual_engine = execute_model_req.virtual_engine
        is_first_multi_step = execute_model_req.is_first_multi_step
        if is_first_multi_step:
            # On the first step, prepare the worker input and model input
            # normally.
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req
            )
            model_input: StatefulModelInput = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids,
                )
            )
        else:
            # On subsequent steps, reuse the cached worker input and model
            # input.
            multi_step_state = self.multi_step_states[virtual_engine]
            worker_input = multi_step_state.worker_input
            model_input = multi_step_state.model_input
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            assert frozen_model_input.attn_metadata is not None
            # Clear the cached decode metadata so that it can be recomputed
            # on the workers.
            frozen_model_input.attn_metadata._cached_decode_metadata = None

        model_input.is_first_multi_step = is_first_multi_step
        model_input.is_last_step = execute_model_req.is_last_step

        if not is_first_multi_step:
            # Broadcast the last sampled token ids to all TP workers so they
            # can update their model input metadata in-place.
            self._prepare_last_sampled_token_ids_for_tp_workers(
                execute_model_req=execute_model_req, model_input=model_input
            )

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(broadcast_data, src=0)

        # Returning an empty dict here keeps this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`.
        return model_input, worker_input, {}
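
    # Sketch of the broadcast payload (the exact serialized fields are
    # determined by as_broadcastable_tensor_dict(); only the calls above are
    # authoritative): on the first step the freshly prepared worker_input and
    # model_input are sent in full from the driver (src=0); on later steps
    # the cached inputs are re-serialized with the updated
    # is_first_multi_step / is_last_step flags and the last sampled token ids
    # prepared by the helper below, so TP workers can advance their state in
    # place.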

    def _prepare_last_sampled_token_ids_for_tp_workers(
        self,
        execute_model_req: ExecuteModelRequest,
        model_input: StatefulModelInput,
    ) -> None:
        """
        Prepare the last sampled token ids for TP workers. If this is the
        last PP rank, the last sampled token ids are already in the
        model_input. If it is NOT the last PP rank, fetch the last sampled
        token ids that are cached in the execute_model_req.
        """
        if get_pp_group().is_last_rank:
            assert (
                model_input.cached_outputs[-1].sampler_output.sampled_token_ids
                is None
            )
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
            model_input.last_sampled_token_ids = model_input.cached_outputs[
                -1
            ].sampled_token_ids
            # Free the sampled token ids from previous steps once they have
            # been pythonized. The last sampled token ids cannot be freed
            # because they are needed for the GPU advance_step.
            for output in model_input.cached_outputs[:-1]:
                if output.pythonized:
                    output.sampled_token_ids = None
        else:
            # Otherwise, get the cached sampled token ids from the
            # execute_model_req.
            assert execute_model_req.last_sampled_token_ids is not None
            model_input.last_sampled_token_ids = (
                execute_model_req.last_sampled_token_ids.cuda()
            )
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids,
            )

            # Free the sampled token ids from previous steps.
            # TODO: we could reuse the sampled token ids tensor from
            # the previous step instead.
            for output in model_input.cached_outputs[:-1]:
                output.sampled_token_ids = None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
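
    # Worked example of the two paths above (illustrative, assuming a
    # pipeline_parallel_size of 2): on the last PP rank the token ids sampled
    # in the previous step already live in
    # model_input.cached_outputs[-1].sampled_token_ids, so they are simply
    # promoted to last_sampled_token_ids; on non-last PP ranks they arrive
    # via execute_model_req.last_sampled_token_ids, are moved to the GPU with
    # .cuda(), and are recorded with add_sampler_output so that
    # cached_outputs[-1] carries them (as the final assert checks).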

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
                                                              torch.Tensor]]]:
        """
        Depending on the current state of the request and multi-step worker,
        this method may skip the normal _prepare_model_input and
        _prepare_worker_input methods and instead use cached values.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers are running an infinite loop with
                    # broadcast_tensor_dict, and the loop stops when the
                    # driver broadcasts an empty input. Send an empty input
                    # to notify all other workers to stop their execution
                    # loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            virtual_engine = execute_model_req.virtual_engine
            (model_input, worker_input,
             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
            assert isinstance(model_input, StatefulModelInput)
            if execute_model_req.is_first_multi_step:
                # Cache the worker input and model input for the next steps.
                self.multi_step_states[virtual_engine] = MultiStepState(
                    worker_input=worker_input, model_input=model_input
                )
        # On TP workers (non-driver ranks):
        else:
            broadcast_data = self._get_worker_input_from_broadcast()
            # If the driver has sent an empty input, stop the worker loop.
            if broadcast_data is None:
                return None
            model_input, worker_input, kwargs = broadcast_data
            assert isinstance(model_input, StatefulModelInput)
            virtual_engine = worker_input.virtual_engine
            if model_input.is_first_multi_step:
                pass
                # TODO: the worker input and model input could be cached here
                # for the next steps. See below for details.
            else:
                # TODO: it is also possible to cache and reuse the worker
                # input and model input here. The idea is essentially the
                # delta optimization for model inputs: the TP workers cache
                # the model input states, and the driver only broadcasts the
                # delta needed for the next step (the sampled_token_ids from
                # the previous step).
                assert isinstance(model_input, StatefulModelInput)
                # Update the last sampled token ids in the model input so
                # that the workers can run the in-place advance_step.
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None),
                    model_input.last_sampled_token_ids,
                )

        assert model_input is not None
        assert worker_input is not None
        return model_input, worker_input, kwargs
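

# Usage sketch (illustrative only; `worker` and the `req_*` objects are
# hypothetical stand-ins for an initialized MultiStepWorker on the driver
# rank and the per-step ExecuteModelRequest built by the scheduler; they are
# not names defined in this module):
#
#   worker.prepare_input(req_first)  # is_first_multi_step=True: inputs are
#                                    # prepared normally and cached per
#                                    # virtual engine
#   worker.prepare_input(req_next)   # is_first_multi_step=False: cached
#                                    # inputs are reused; only the last
#                                    # sampled token ids are propagated for
#                                    # the in-place advance_step
#   worker.prepare_input(req_last)   # is_last_step=True is forwarded to the
#                                    # model runner via
#                                    # model_input.is_last_step
#
# In practice prepare_input() is normally driven by the base worker's
# execute_model() path rather than called directly.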