multi_step_tpu_worker.py

import dataclasses
from typing import Dict, Optional, Tuple

import torch

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.distributed import broadcast_tensor_dict
from aphrodite.worker.tpu_model_runner import ModelInputForTPU
from aphrodite.worker.tpu_worker import TPUWorker
from aphrodite.worker.worker_base import WorkerInput
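

# MultiStepTPUWorker executes several decode steps per scheduler call: the
# full worker/model inputs are prepared and broadcast only on the first step
# of a multi-step run, and every later step exchanges just the two boolean
# step flags while reusing a cached model input.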
class MultiStepTPUWorker(TPUWorker):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Model input built on the first step of a multi-step run; reused
        # (with refreshed step flags) on every subsequent step.
        self.cached_model_input: Optional[ModelInputForTPU] = None
    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        is_first_multi_step = execute_model_req.is_first_multi_step
        is_last_step = execute_model_req.is_last_step
        if is_first_multi_step:
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req
            )
            # One multi-step run covers num_lookahead_slots + 1 decode steps.
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1,
            )
            model_input: ModelInputForTPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids,
                )
            )
            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input, async_callback=execute_model_req.async_callback
                )
        else:
            # Intermediate step: reuse the input cached on the first step.
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=is_first_multi_step,
            is_last_step=is_last_step,
        )

        if self.do_metadata_broadcast:
            if is_first_multi_step:
                # First step: broadcast the full worker and model inputs.
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict()
                )
                broadcast_tensor_dict(broadcast_data, src=0)
            else:
                # Later steps: only the step flags change, so send just
                # those two keys and let followers patch their cached input.
                broadcast_data = {
                    "is_first_multi_step": is_first_multi_step,
                    "is_last_step": is_last_step,
                }
                broadcast_tensor_dict(broadcast_data, src=0)

        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`.
        return model_input, worker_input, {}
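
    # Called once per step on every worker. The driver builds and broadcasts
    # the inputs; followers reconstruct them from the broadcast, caching the
    # model input so intermediate steps can be served from the two-flag
    # payload alone.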
    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[
        Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]
    ]:
        if self.is_driver_worker:
            if execute_model_req is None:
                # An empty broadcast tells the followers to shut down.
                if self.do_metadata_broadcast:
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = (
                self._get_driver_input_and_broadcast(execute_model_req)
            )
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}
        else:
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

            if len(broadcast_data) == 2:
                # Intermediate step: the driver sent only the two step
                # flags; patch them into the cached model input.
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=broadcast_data["is_first_multi_step"],
                    is_last_step=broadcast_data["is_last_step"],
                )
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

            # First step: rebuild the full inputs from the broadcast payload
            # and cache the model input for the steps that follow.
            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data
            )
            model_input = (
                self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                    broadcast_data
                )
            )
            self.cached_model_input = model_input
            return model_input, worker_input, {}
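

# A minimal sketch of the wire protocol implied above (illustrative only, not
# executed by this module): on the first step the driver broadcasts the merged
# dict
#     {**worker_input.as_broadcastable_tensor_dict(),
#      **model_input.as_broadcastable_tensor_dict()}
# and on every later step exactly two keys,
#     {"is_first_multi_step": False, "is_last_step": <bool>},
# which is why `prepare_input` can dispatch on `len(broadcast_data) == 2`.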