ray_tpu_executor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import asyncio
  2. import os
  3. from collections import defaultdict
  4. from itertools import islice, repeat
  5. from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
  6. Union)
  7. from loguru import logger
  8. from aphrodite import envs
  9. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  10. from aphrodite.common.utils import (get_aphrodite_instance_id,
  11. get_distributed_init_method, get_ip,
  12. get_open_port, make_async)
  13. from aphrodite.executor.executor_base import ExecutorAsyncBase
  14. from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
  15. from aphrodite.executor.tpu_executor import TPUExecutor
# `ray` (from aphrodite.executor.ray_utils) may be None, so the scheduling
# strategy is only imported when Ray is actually available.
if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# PlacementGroup is only referenced in (string) type annotations below.
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# Read from `envs` once at import time; forwarded (as a string) to every
# worker through the "update_environment_variables" call during worker setup.
APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION
  21. class RayTPUExecutor(TPUExecutor):
  22. def __init__(self, *args, **kwargs):
  23. # This is non-None when the execute model loop is running
  24. # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
  25. self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
  26. # Updated by implementations that require additional args to be passed
  27. # to the _run_workers execute_model call
  28. self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
  29. super().__init__(*args, **kwargs)
  30. def _init_executor(self) -> None:
  31. assert self.parallel_config.distributed_executor_backend == "ray"
  32. placement_group = self.parallel_config.placement_group
  33. # Disable Ray usage stats collection.
  34. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  35. if ray_usage != "1":
  36. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  37. # Create the parallel TPU workers.
  38. self._init_workers_ray(placement_group)
  39. def _init_workers_ray(self, placement_group: "PlacementGroup",
  40. **ray_remote_kwargs):
  41. # The driver dummy worker does not actually use any resources.
  42. # It holds the resource for the driver worker.
  43. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
  44. # The remaining workers are the actual ray actors.
  45. self.workers: List[RayWorkerWrapper] = []
  46. # Create the workers.
  47. driver_ip = get_ip()
  48. for bundle_id, bundle in enumerate(placement_group.bundle_specs):
  49. if not bundle.get("TPU", 0):
  50. continue
  51. scheduling_strategy = PlacementGroupSchedulingStrategy(
  52. placement_group=placement_group,
  53. placement_group_capture_child_tasks=True,
  54. placement_group_bundle_index=bundle_id,
  55. )
  56. assert self.speculative_config is None
  57. worker_module_name = "aphrodite.task_handler.tpu_worker"
  58. worker_class_name = "TPUWorker"
  59. worker = ray.remote(
  60. num_cpus=0,
  61. resources={"TPU": 1},
  62. scheduling_strategy=scheduling_strategy,
  63. **ray_remote_kwargs,
  64. )(RayWorkerWrapper).remote(
  65. worker_module_name=worker_module_name,
  66. worker_class_name=worker_class_name,
  67. trust_remote_code=self.model_config.trust_remote_code,
  68. )
  69. worker_ip = ray.get(worker.get_node_ip.remote())
  70. if worker_ip == driver_ip and self.driver_dummy_worker is None:
  71. # If the worker is on the same node as the driver, we use it
  72. # as the resource holder for the driver process.
  73. self.driver_dummy_worker = worker
  74. self.driver_worker = RayWorkerWrapper(
  75. worker_module_name=worker_module_name,
  76. worker_class_name=worker_class_name,
  77. trust_remote_code=self.model_config.trust_remote_code,
  78. )
  79. else:
  80. # Else, added to the list of workers.
  81. self.workers.append(worker)
  82. if self.driver_dummy_worker is None:
  83. raise ValueError(
  84. "Ray does not allocate any TPUs on the driver node. Consider "
  85. "adjusting the Ray placement group or running the driver on a "
  86. "TPU node.")
  87. # Get the set of TPU IDs used on each node.
  88. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
  89. use_dummy_driver=True)
  90. node_workers = defaultdict(list)
  91. for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
  92. node_workers[node_id].append(i)
  93. APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()
  94. # Set environment variables for the driver and workers.
  95. all_args_to_update_environment_variables = [({
  96. "APHRODITE_INSTANCE_ID":
  97. APHRODITE_INSTANCE_ID,
  98. "APHRODITE_TRACE_FUNCTION":
  99. str(APHRODITE_TRACE_FUNCTION),
  100. }, ) for _ in worker_node_and_gpu_ids]
  101. self._run_workers("update_environment_variables",
  102. all_args=all_args_to_update_environment_variables)
  103. if len(node_workers) == 1:
  104. # in single node case, we don't need to get the IP address.
  105. # the loopback address is sufficient
  106. # NOTE: a node may have several IP addresses, one for each
  107. # network interface. `get_ip()` might return any of them,
  108. # while they might not work for communication inside the node
  109. # if the network setup is complicated. Using the loopback address
  110. # solves this issue, as it always works for communication inside
  111. # the node.
  112. driver_ip = "127.0.0.1"
  113. distributed_init_method = get_distributed_init_method(
  114. driver_ip, get_open_port())
  115. # Initialize the actual workers inside worker wrapper.
  116. init_worker_all_kwargs = [
  117. self._get_worker_kwargs(
  118. local_rank=node_workers[node_id].index(rank),
  119. rank=rank,
  120. distributed_init_method=distributed_init_method,
  121. ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
  122. ]
  123. self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
  124. self._run_workers("init_device")
  125. self._run_workers("load_model",
  126. max_concurrent_workers=self.parallel_config.
  127. max_parallel_loading_workers)
  128. def _driver_execute_model(
  129. self,
  130. execute_model_req: Optional[ExecuteModelRequest] = None
  131. ) -> List[SamplerOutput]:
  132. """Run execute_model in the driver worker.
  133. Passing None will cause the driver to stop the model execution
  134. loop running in each of the remote workers.
  135. """
  136. return self.driver_worker.execute_method("execute_model",
  137. execute_model_req)
  138. def _run_workers(
  139. self,
  140. method: str,
  141. *args,
  142. async_run_remote_workers_only: bool = False,
  143. all_args: Optional[List[Tuple[Any, ...]]] = None,
  144. all_kwargs: Optional[List[Dict[str, Any]]] = None,
  145. use_dummy_driver: bool = False,
  146. max_concurrent_workers: Optional[int] = None,
  147. use_ray_compiled_dag: bool = False,
  148. **kwargs,
  149. ) -> Any:
  150. """Runs the given method on all workers. Can be used in the following
  151. ways:
  152. - async_run_remote_workers_only: If True the method will be run only
  153. in the remote workers, not the driver worker. It will also be
  154. run asynchronously and return a list of futures rather than blocking
  155. on the results.
  156. - args/kwargs: All workers share the same args/kwargs
  157. - all_args/all_kwargs: args/kwargs for each worker are specified
  158. individually
  159. """
  160. if max_concurrent_workers:
  161. raise NotImplementedError(
  162. "max_concurrent_workers is not supported yet.")
  163. count = len(self.workers)
  164. all_worker_args = repeat(args, count) if all_args is None \
  165. else islice(all_args, 1, None)
  166. all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
  167. else islice(all_kwargs, 1, None)
  168. # Start the ray workers first.
  169. ray_worker_outputs = [
  170. worker.execute_method.remote(method, *worker_args, **worker_kwargs)
  171. for (worker, worker_args, worker_kwargs
  172. ) in zip(self.workers, all_worker_args, all_worker_kwargs)
  173. ]
  174. if async_run_remote_workers_only:
  175. # Just return futures
  176. return ray_worker_outputs
  177. driver_args = args if all_args is None else all_args[0]
  178. driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
  179. # Start the driver worker after all the ray workers.
  180. if not use_dummy_driver:
  181. driver_worker_output = self.driver_worker.execute_method(
  182. method, *driver_args, **driver_kwargs)
  183. else:
  184. assert self.driver_dummy_worker is not None
  185. driver_worker_output = ray.get(
  186. self.driver_dummy_worker.execute_method.remote(
  187. method, *driver_args, **driver_kwargs))
  188. # Get the results of the ray workers.
  189. if self.workers:
  190. ray_worker_outputs = ray.get(ray_worker_outputs)
  191. return [driver_worker_output] + ray_worker_outputs
  192. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  193. """Wait for futures returned from _run_workers() with
  194. async_run_remote_workers_only to complete."""
  195. ray.get(parallel_worker_tasks)
  196. def determine_num_available_blocks(self) -> Tuple[int, int]:
  197. num_blocks = self._run_workers("determine_num_available_blocks", )
  198. num_tpu_blocks = min(b[0] for b in num_blocks)
  199. num_cpu_blocks = min(b[1] for b in num_blocks)
  200. return num_tpu_blocks, num_cpu_blocks
  201. def initialize_cache(self, num_gpu_blocks: int,
  202. num_cpu_blocks: int) -> None:
  203. logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
  204. num_cpu_blocks)
  205. self.cache_config.num_gpu_blocks = num_gpu_blocks
  206. self.cache_config.num_cpu_blocks = num_cpu_blocks
  207. self._run_workers("initialize_cache",
  208. num_gpu_blocks=num_gpu_blocks,
  209. num_cpu_blocks=num_cpu_blocks)
  210. def execute_model(
  211. self,
  212. execute_model_req: ExecuteModelRequest,
  213. ) -> List[SamplerOutput]:
  214. if self.parallel_worker_tasks is None:
  215. self.parallel_worker_tasks = self._run_workers(
  216. "start_worker_execution_loop",
  217. async_run_remote_workers_only=True,
  218. **self.extra_execute_model_run_workers_kwargs)
  219. # Only the driver worker returns the sampling results.
  220. return self._driver_execute_model(execute_model_req)
  221. def stop_remote_worker_execution_loop(self) -> None:
  222. if self.parallel_worker_tasks is None:
  223. return
  224. self._driver_execute_model()
  225. parallel_worker_tasks = self.parallel_worker_tasks
  226. self.parallel_worker_tasks = None
  227. # Ensure that workers exit model loop cleanly
  228. # (this will raise otherwise)
  229. self._wait_for_tasks_completion(parallel_worker_tasks)
  230. class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
  231. def __init__(self, *args, **kwargs):
  232. super().__init__(*args, **kwargs)
  233. self.driver_exec_method = make_async(self.driver_worker.execute_method)
  234. async def execute_model_async(
  235. self,
  236. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  237. if self.parallel_worker_tasks is None:
  238. # Start model execution loop running in the parallel workers
  239. self.parallel_worker_tasks = asyncio.create_task(
  240. self._start_worker_execution_loop())
  241. # Only the driver worker returns the sampling results.
  242. return await self._driver_execute_model_async(execute_model_req)
  243. async def stop_remote_worker_execution_loop_async(self) -> None:
  244. if self.parallel_worker_tasks is None:
  245. return
  246. await self._driver_execute_model_async()
  247. parallel_worker_tasks = self.parallel_worker_tasks
  248. self.parallel_worker_tasks = None
  249. # Ensure that workers exit model loop cleanly
  250. # (this will raise otherwise)
  251. await parallel_worker_tasks
  252. async def _driver_execute_model_async(
  253. self,
  254. execute_model_req: Optional[ExecuteModelRequest] = None
  255. ) -> List[SamplerOutput]:
  256. return await self.driver_exec_method("execute_model",
  257. execute_model_req)
  258. async def _start_worker_execution_loop(self):
  259. coros = [
  260. worker.execute_method.remote("start_worker_execution_loop")
  261. for worker in self.workers
  262. ]
  263. return await asyncio.gather(*coros)