multi_step_worker.py
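
"""Worker implementation for multi-step decoding.

MultiStepWorker wraps the base model runner in a MultiStepModelRunner and, on
the driver, caches the worker input and model input prepared on the first step
of a multi-step sequence (one MultiStepState per virtual engine) so that
subsequent steps can reuse them instead of re-running the normal input
preparation. The last sampled token ids are propagated to the TP workers so
they can advance the cached model input in place.
"""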

from dataclasses import dataclass
from typing import List, Optional, Tuple

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
from aphrodite.task_handler.multi_step_model_runner import (
    MultiStepModelRunner, StatefulModelInput)
from aphrodite.task_handler.worker import Worker, WorkerInput


@dataclass
class MultiStepState:
    worker_input: WorkerInput
    model_input: StatefulModelInput
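
# NOTE: MultiStepState above is the per-virtual-engine cache entry: the driver
# fills it on the first step of a multi-step sequence (see prepare_input) and
# _get_driver_input_and_broadcast reuses it on every subsequent step.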


class MultiStepWorker(Worker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        base_model_runner = self.model_runner
        # for multi-step model, wrap the model runner with MultiStepModelRunner
        self.model_runner = MultiStepModelRunner(
            base_model_runner,
            base_model_runner.model_config,
            base_model_runner.parallel_config,
            base_model_runner.scheduler_config,
            base_model_runner.device_config,
            base_model_runner.cache_config,
            load_config=base_model_runner.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=base_model_runner.is_driver_worker,
            prompt_adapter_config=base_model_runner.prompt_adapter_config,
        )

        pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
        self.multi_step_states: List[Optional[MultiStepState]] = [
            None
        ] * pipeline_parallel_size
        self.temp_output = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput]:
        """
        Get the driver input and broadcast it to other workers.
        """
        assert self.is_driver_worker

        virtual_engine = execute_model_req.virtual_engine
        is_first_multi_step = execute_model_req.is_first_multi_step
        if is_first_multi_step:
            # on the first step we prepare the worker input and model input
            # normally
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req
            )
            model_input: StatefulModelInput = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids,
                )
            )
        else:
            # on subsequent steps we reuse the cached worker input and model
            # input
            multi_step_state = self.multi_step_states[virtual_engine]
            worker_input = multi_step_state.worker_input
            model_input = multi_step_state.model_input
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            assert frozen_model_input.attn_metadata is not None
            # clear the cached decode metadata so that it can be recomputed on
            # the workers
            frozen_model_input.attn_metadata._cached_decode_metadata = None

        model_input.is_first_multi_step = is_first_multi_step
        model_input.is_last_step = execute_model_req.is_last_step

        if not is_first_multi_step:
            # broadcast the last sampled token ids to all TP workers so they
            # can update their model input metadata in-place.
            self._prepare_last_sampled_token_ids_for_tp_workers(
                execute_model_req=execute_model_req, model_input=model_input
            )

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(broadcast_data, src=0)

        return model_input, worker_input

    def _prepare_last_sampled_token_ids_for_tp_workers(
        self,
        execute_model_req: ExecuteModelRequest,
        model_input: StatefulModelInput,
    ) -> None:
        """
        Prepare the last sampled token ids for TP workers. If this is the last
        PP rank, the last sampled token ids are already in the model_input.
        If it is NOT the last PP rank, we need to get the last sampled token
        ids that are cached in the execute_model_req.
        """
        if get_pp_group().is_last_rank:
            assert (
                model_input.cached_outputs[-1].sampler_output.sampled_token_ids
                is None
            )
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
            model_input.last_sampled_token_ids = model_input.cached_outputs[
                -1
            ].sampled_token_ids
            # free sampled token ids from the previous step if they have been
            # pythonized. We cannot free the last sampled token ids because
            # they are needed for the GPU advance_step.
            for output in model_input.cached_outputs[:-1]:
                if output.pythonized:
                    output.sampled_token_ids = None
        else:
            # otherwise we need to get the cached sampled token ids from the
            # execute_model_req
            assert execute_model_req.last_sampled_token_ids is not None
            model_input.last_sampled_token_ids = (
                execute_model_req.last_sampled_token_ids.cuda()
            )
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids,
            )

            # free sampled token ids from the previous step.
            # TODO: we could reuse the sampled token ids tensor from
            # the previous step instead.
            for output in model_input.cached_outputs[:-1]:
                output.sampled_token_ids = None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
        """
        Depending on the current state of the request and the multi-step
        worker, this method may skip the normal _prepare_model_input and
        _prepare_worker_input methods and instead use cached values.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers are running an infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            virtual_engine = execute_model_req.virtual_engine
            model_input, worker_input = self._get_driver_input_and_broadcast(
                execute_model_req
            )
            assert isinstance(model_input, StatefulModelInput)
            if execute_model_req.is_first_multi_step:
                # cache the worker input and model input for the next steps
                self.multi_step_states[virtual_engine] = MultiStepState(
                    worker_input=worker_input, model_input=model_input
                )
        else:
            # otherwise this is a TP worker: read the inputs broadcast by the
            # driver
            broadcast_data = self._get_worker_input_from_broadcast()
            # if the driver has sent an empty input, we should stop the worker
            # loop
            if broadcast_data is None:
                return None
            model_input, worker_input = broadcast_data
            assert isinstance(model_input, StatefulModelInput)
            virtual_engine = worker_input.virtual_engine
            if model_input.is_first_multi_step:
                pass
                # TODO: we could cache the worker input and model input for the
                # next steps. See below for details.
            else:
                # TODO: it is also possible to cache and reuse the worker input
                # and model input here. The idea is essentially the delta
                # optimization for model_inputs: the TP workers cache the model
                # input state and the driver only broadcasts the delta needed
                # for the next step (sampled_token_ids from the previous step).
                assert isinstance(model_input, StatefulModelInput)
                # we need to update the last sampled token ids in the model
                # input for the workers so that they can run the in-place
                # advance_step
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None),
                    model_input.last_sampled_token_ids,
                )

        assert model_input is not None
        assert worker_input is not None
        return model_input, worker_input
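

# Illustrative driver-side call sequence (a hedged sketch, not part of the
# Aphrodite API): the executor that owns this worker is assumed to flag the
# first/last step on the ExecuteModelRequest and call prepare_input() once per
# step; prepare_input() returns None only when the driver is given no request
# or a TP worker receives the empty broadcast. `worker`, `req`, and `num_steps`
# are hypothetical names used only for illustration.
#
#   for step in range(num_steps):
#       req.is_first_multi_step = (step == 0)
#       req.is_last_step = (step == num_steps - 1)
#       # step 0 prepares and caches inputs; later steps reuse MultiStepState
#       model_input, worker_input = worker.prepare_input(req)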