multi_step_worker.py

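"""Multi-step decoding worker.

Wraps the base model runner in a MultiStepModelRunner and, on the driver,
caches the prepared worker and model inputs per virtual engine so that the
later steps of a multi-step round reuse them instead of re-preparing inputs,
only updating the last sampled token ids between steps.
"""
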
import dataclasses
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.worker.model_runner_base import BroadcastableModelInput
from aphrodite.worker.multi_step_model_runner import (
    MultiStepModelRunner, StatefulModelInput)
from aphrodite.worker.worker import Worker, WorkerInput

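# Cached inputs for one virtual engine: captured on the first step of a
# multi-step round and reused by the driver for the remaining steps.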
@dataclass
class MultiStepState:
    worker_input: WorkerInput
    model_input: StatefulModelInput


class MultiStepWorker(Worker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        base_model_runner = self.model_runner
        # for multi-step model, wrap the model runner with MultiStepModelRunner
        self.model_runner = MultiStepModelRunner(
            base_model_runner,
            base_model_runner.model_config,
            base_model_runner.parallel_config,
            base_model_runner.scheduler_config,
            base_model_runner.device_config,
            base_model_runner.cache_config,
            load_config=base_model_runner.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=base_model_runner.is_driver_worker,
            prompt_adapter_config=base_model_runner.prompt_adapter_config,
        )

        pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
        self.multi_step_states: List[Optional[MultiStepState]] = [
            None
        ] * pipeline_parallel_size
        self.temp_output = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Get the driver input and broadcast it to other workers.
        """
        assert self.is_driver_worker
        virtual_engine = execute_model_req.virtual_engine
        is_first_multi_step = execute_model_req.is_first_multi_step
        if is_first_multi_step:
            # on first step we prepare the worker input and model input
            # normally
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req
            )
            model_input: StatefulModelInput = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                    model_input.frozen_model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            # on subsequent steps we reuse the worker input and model input
            multi_step_state = self.multi_step_states[virtual_engine]
            worker_input = multi_step_state.worker_input
            model_input = multi_step_state.model_input
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            assert frozen_model_input.attn_metadata is not None
            # clear the cached decode metadata so that it can be recomputed on
            # the workers
            frozen_model_input.attn_metadata._cached_decode_metadata = None

        model_input.is_first_multi_step = is_first_multi_step
        model_input.is_last_step = execute_model_req.is_last_step

        if not is_first_multi_step:
            # we broadcast the last sampled token ids to all TP workers so they
            # can update their model input metadata in-place.
            self._prepare_last_sampled_token_ids_for_tp_workers(
                execute_model_req=execute_model_req, model_input=model_input
            )

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(broadcast_data, src=0)
        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}
    def _prepare_last_sampled_token_ids_for_tp_workers(
        self,
        execute_model_req: ExecuteModelRequest,
        model_input: StatefulModelInput,
    ) -> None:
  93. """
  94. Prepare the last sampled token ids for TP workers. If it's the last
  95. PP rank, then the last sampled token ids are already in the model_input.
  96. If it is NOT the last PP rank, then we need to get the last sampled
  97. token that is cached in the execute_model_req.
  98. """
        if get_pp_group().is_last_rank:
            assert (
                model_input.cached_outputs[-1].sampler_output.sampled_token_ids
                is None
            )
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
            model_input.last_sampled_token_ids = model_input.cached_outputs[
                -1
            ].sampled_token_ids
            # free the sampled token ids from previous steps if they have been
            # pythonized. Cannot free the last sampled token ids because
            # we need them for the GPU advance_step.
            for output in model_input.cached_outputs[:-1]:
                if output.pythonized:
                    output.sampled_token_ids = None
        else:
            # otherwise we need to get the cached sampled token ids from the
            # execute_model_req
            assert execute_model_req.last_sampled_token_ids is not None
            model_input.last_sampled_token_ids = (
                execute_model_req.last_sampled_token_ids.cuda()
            )
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids,
            )

            # free sampled token ids from the previous step.
            # TODO: we could reuse the sampled token ids tensor from
            # the previous step instead.
            for output in model_input.cached_outputs[:-1]:
                output.sampled_token_ids = None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
                                                              torch.Tensor]]]:
  136. """
  137. Depending on the current state of the request and multi step worker,
  138. this method may skip the normal _prepare_model_input and
  139. _prepare_worker_input methods and instead used cached values.
  140. """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers are running an infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            virtual_engine = execute_model_req.virtual_engine
            (model_input, worker_input,
             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
            assert isinstance(model_input, StatefulModelInput)
            if execute_model_req.is_first_multi_step:
                # cache the worker input and model input for the next steps
                self.multi_step_states[virtual_engine] = MultiStepState(
                    worker_input=worker_input, model_input=model_input
                )
        # if TP workers
        else:
            broadcast_data = self._get_worker_input_from_broadcast()
            # if the driver has sent an empty input, we should stop the worker
            # loop
            if broadcast_data is None:
                return None
            model_input, worker_input, kwargs = broadcast_data
            assert isinstance(model_input, StatefulModelInput)
            virtual_engine = worker_input.virtual_engine
            if model_input.is_first_multi_step:
                pass
                # TODO: Can cache the worker input and model input for the
                # next steps. See below for details
            else:
                # TODO: it is possible to also cache and reuse the worker
                # input and model input here. The idea is essentially the
                # delta optimization for model inputs: the TP workers cache
                # the model input state and we only broadcast the delta needed
                # for the next step (sampled_token_ids from the previous step)
                assert isinstance(model_input, StatefulModelInput)

                # we need to update the last sampled token ids in the model
                # input for the workers so that they can run in-place
                # advance_step
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None),
                    model_input.last_sampled_token_ids,
                )
        assert model_input is not None
        assert worker_input is not None
        return model_input, worker_input, kwargs
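

# Illustrative driver-side flow (a sketch, not part of the worker itself):
# the executor is assumed to call prepare_input() once per step of a
# multi-step round. Names such as `make_request` and `num_steps` are
# hypothetical placeholders for however the ExecuteModelRequest is built.
#
#     worker = MultiStepWorker(...)            # driver rank
#     for step in range(num_steps):
#         req = make_request(step)             # is_first_multi_step on step 0,
#                                              # is_last_step on the final step
#         prepared = worker.prepare_input(req)
#         if prepared is None:                 # nothing to run this round
#             break
#         model_input, worker_input, kwargs = prepared
#         ...                                  # run the model / sampler here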