ray_gpu_executor.py

import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, it uses Ray's compiled DAG API,
# which optimizes the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
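# Note: bool(os.getenv(...)) is truthy for any non-empty string, so any set
# value (even "0") enables the compiled DAG path; leave the variable unset to
# keep it disabled.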


class RayGPUExecutor(DistributedGPUExecutor):

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()
            self.extra_execute_model_run_workers_kwargs[
                "use_ray_compiled_dag"] = True

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })
        return ray_remote_kwargs
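
    # Worker setup flow (summary of the method below): one Ray actor is
    # created per GPU bundle in the placement group; the actor co-located
    # with the driver only holds the GPU resource (the "dummy" worker) while
    # the driver runs its own in-process worker; environment variables are
    # then synchronized across nodes, every rank is initialized with the
    # shared distributed init method, and the model weights are loaded.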
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if self.speculative_config is not None:
                worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
                worker_class_name = "create_spec_worker"
            else:
                worker_module_name = "aphrodite.task_handler.worker"
                worker_class_name = "Worker"

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            os.getenv("APHRODITE_TRACE_FUNCTION", "0"),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
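        # Each entry above is a one-element tuple: the positional-argument
        # tuple passed to that worker's update_environment_variables call.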
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True the method will be run only
          in the remote workers, not the driver worker. It will also be
          run asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)
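        # Index 0 of all_args/all_kwargs above is reserved for the driver
        # worker; islice(..., 1, None) hands the remaining entries to the
        # remote workers in order.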

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO: Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
            ray_worker_outputs = []
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
            ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)
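
    # The compiled DAG binds execute_model_compiled_dag_remote on every worker
    # once, so each execution step reuses the same DAG instead of issuing a
    # fresh execute_method RPC per worker; results come back through channels
    # that _run_workers reads via begin_read()/end_read().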
    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions rather than raw strings so that, e.g.,
        # "2.10" is not considered older than "2.9".
        if pkg_resources.parse_version(current_version) < \
                pkg_resources.parse_version(required_version):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.distributed_executor_backend == "ray"

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}. ")


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
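

# Usage sketch (assumption: this executor is normally constructed by the
# engine rather than by user code): when
# parallel_config.distributed_executor_backend is set to "ray", the engine
# instantiates RayGPUExecutor (or the async variant) and drives it through
# the executor interface, e.g. execute_model from DistributedGPUExecutor and
# check_health defined above.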