# ray_tpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional,
                    Tuple, Union)

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.executor.tpu_executor import TPUExecutor

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

APHRODITE_TRACE_FUNCTION = int(os.getenv("APHRODITE_TRACE_FUNCTION", 0))


class RayTPUExecutor(TPUExecutor):

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual Ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            worker_module_name = "aphrodite.task_handler.tpu_worker"
            worker_class_name = "TPUWorker"

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Otherwise, add it to the list of remote workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "APHRODITE_INSTANCE_ID": APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION": str(APHRODITE_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # In the single-node case, we don't need to resolve the IP
            # address; the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them, and the
            # returned address might not work for communication inside the
            # node if the network setup is complicated. Using the loopback
            # address avoids this issue, as it always works for communication
            # inside the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside the worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)
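
    # For reference, a hypothetical placement group for two TPU hosts might
    # expose bundle_specs such as [{"TPU": 1.0}, {"TPU": 1.0}]. Only bundles
    # that reserve a "TPU" resource get a RayWorkerWrapper actor above;
    # bundles without a TPU entry are skipped. The exact bundle shape is an
    # assumption for illustration, not something this module checks.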

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True, the method will be run only
          in the remote workers, not the driver worker. It will also be run
          asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: all workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        if async_run_remote_workers_only:
            # Just return the futures.
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs
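
    # Illustrative call patterns for _run_workers (comments only; `executor`
    # is a hypothetical instance name):
    #
    #   # Shared args/kwargs: the driver and every remote worker get the same
    #   # call.
    #   executor._run_workers("init_device")
    #
    #   # Per-worker args: entry 0 of all_args/all_kwargs goes to the driver;
    #   # the rest (via islice(..., 1, None)) go to the remote Ray workers,
    #   # in order.
    #   executor._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    #
    #   # Remote-only, non-blocking: returns Ray object refs to collect later
    #   # with _wait_for_tasks_completion().
    #   refs = executor._run_workers("start_worker_execution_loop",
    #                                async_run_remote_workers_only=True)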

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_blocks = self._run_workers("determine_num_available_blocks")
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # loguru uses `{}`-style formatting, so interpolate directly instead
        # of passing printf-style arguments.
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        self._wait_for_tasks_completion(parallel_worker_tasks)
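
    # Lifecycle sketch (comments only; `executor` and the request names are
    # hypothetical): the first execute_model() call lazily starts the busy
    # loop on the remote workers and stashes the futures in
    # self.parallel_worker_tasks; later calls only go through the driver.
    # stop_remote_worker_execution_loop() passes None through
    # _driver_execute_model() to break the loop, then blocks on the futures
    # so that worker-side errors surface here.
    #
    #   outputs = executor.execute_model(first_request)   # starts the loop
    #   outputs = executor.execute_model(next_request)    # loop already running
    #   executor.stop_remote_worker_execution_loop()      # drains the workers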


class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start the model execution loop running in the parallel workers.
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
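

# A minimal, runnable sketch (illustrative values only, not used by the
# executor) of the argument-splitting convention in _run_workers above:
# entry 0 of all_args/all_kwargs is consumed by the driver worker, and
# islice(..., 1, None) hands the remaining entries to the remote Ray workers
# in order.
if __name__ == "__main__":
    _all_args = [("driver",), ("remote-0",), ("remote-1",)]
    _driver_args = _all_args[0]
    _remote_args = list(islice(_all_args, 1, None))
    assert _driver_args == ("driver",)
    assert _remote_args == [("remote-0",), ("remote-1",)]
    print("driver gets:", _driver_args, "| remote workers get:", _remote_args)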