ray_gpu_executor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import asyncio
  2. import os
  3. import pickle
  4. from collections import defaultdict
  5. from itertools import islice, repeat
  6. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
  7. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  8. from aphrodite.common.utils import (get_aphrodite_instance_id,
  9. get_distributed_init_method, get_ip,
  10. get_open_port, make_async)
  11. from aphrodite.executor.distributed_gpu_executor import (
  12. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  13. from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
  14. if ray is not None:
  15. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  16. if TYPE_CHECKING:
  17. from ray.util.placement_group import PlacementGroup
  18. # If the env var is set, it uses the Ray's compiled DAG API
  19. # which optimizes the control plane overhead.
  20. # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
  21. USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
  22. class RayGPUExecutor(DistributedGPUExecutor):
  23. def _init_executor(self) -> None:
  24. assert self.parallel_config.distributed_executor_backend == "ray"
  25. placement_group = self.parallel_config.placement_group
  26. # Disable Ray usage stats collection.
  27. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  28. if ray_usage != "1":
  29. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  30. # Create the parallel GPU workers.
  31. self._init_workers_ray(placement_group)
  32. self.forward_dag = None
  33. if USE_RAY_COMPILED_DAG:
  34. self.forward_dag = self._compiled_ray_dag()
  35. def _configure_ray_workers_use_nsight(self,
  36. ray_remote_kwargs) -> Dict[str, Any]:
  37. # If nsight profiling is enabled, we need to set the profiling
  38. # configuration for the ray workers as runtime env.
  39. runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
  40. runtime_env.update({
  41. "nsight": {
  42. "t": "cuda,cudnn,cublas",
  43. "o": "'worker_process_%p'",
  44. "cuda-graph-trace": "node",
  45. }
  46. })
  47. return ray_remote_kwargs
  48. def _init_workers_ray(self, placement_group: "PlacementGroup",
  49. **ray_remote_kwargs):
  50. if self.parallel_config.tensor_parallel_size == 1:
  51. # For single GPU case, we use a ray worker with constrained memory.
  52. num_gpus = self.cache_config.gpu_memory_utilization
  53. else:
  54. # Otherwise, the ray workers are allocated with a full GPU.
  55. num_gpus = 1
  56. # The driver dummy worker does not actually use any resources.
  57. # It holds the resource for the driver worker.
  58. self.driver_dummy_worker: RayWorkerWrapper = None
  59. # The remaining workers are the actual ray actors.
  60. self.workers: List[RayWorkerWrapper] = []
  61. if self.parallel_config.ray_workers_use_nsight:
  62. ray_remote_kwargs = self._configure_ray_workers_use_nsight(
  63. ray_remote_kwargs)
  64. # Create the workers.
  65. driver_ip = get_ip()
  66. for bundle_id, bundle in enumerate(placement_group.bundle_specs):
  67. if not bundle.get("GPU", 0):
  68. continue
  69. scheduling_strategy = PlacementGroupSchedulingStrategy(
  70. placement_group=placement_group,
  71. placement_group_capture_child_tasks=True,
  72. placement_group_bundle_index=bundle_id,
  73. )
  74. if self.speculative_config is not None:
  75. worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
  76. worker_class_name = "create_spec_worker"
  77. else:
  78. worker_module_name = "aphrodite.task_handler.worker"
  79. worker_class_name = "Worker"
  80. worker = ray.remote(
  81. num_cpus=0,
  82. num_gpus=num_gpus,
  83. scheduling_strategy=scheduling_strategy,
  84. **ray_remote_kwargs,
  85. )(RayWorkerWrapper).remote(
  86. worker_module_name=worker_module_name,
  87. worker_class_name=worker_class_name,
  88. trust_remote_code=self.model_config.trust_remote_code,
  89. )
  90. worker_ip = ray.get(worker.get_node_ip.remote())
  91. if worker_ip == driver_ip and self.driver_dummy_worker is None:
  92. # If the worker is on the same node as the driver, we use it
  93. # as the resource holder for the driver process.
  94. self.driver_dummy_worker = worker
  95. self.driver_worker = RayWorkerWrapper(
  96. worker_module_name=worker_module_name,
  97. worker_class_name=worker_class_name,
  98. trust_remote_code=self.model_config.trust_remote_code,
  99. )
  100. else:
  101. # Else, added to the list of workers.
  102. self.workers.append(worker)
  103. if self.driver_dummy_worker is None:
  104. raise ValueError(
  105. "Ray does not allocate any GPUs on the driver node. Consider "
  106. "adjusting the Ray placement group or running the driver on a "
  107. "GPU node.")
  108. # Get the set of GPU IDs used on each node.
  109. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
  110. use_dummy_driver=True)
  111. node_workers = defaultdict(list)
  112. node_gpus = defaultdict(list)
  113. for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
  114. node_workers[node_id].append(i)
  115. node_gpus[node_id].extend(gpu_ids)
  116. for node_id, gpu_ids in node_gpus.items():
  117. node_gpus[node_id] = sorted(gpu_ids)
  118. APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()
  119. # Set environment variables for the driver and workers.
  120. all_args_to_update_environment_variables = [({
  121. "CUDA_VISIBLE_DEVICES":
  122. ",".join(map(str, node_gpus[node_id])),
  123. "APHRODITE_INSTANCE_ID":
  124. APHRODITE_INSTANCE_ID,
  125. "APHRODITE_TRACE_FUNCTION":
  126. os.getenv("APHRODITE_TRACE_FUNCTION", "0"),
  127. }, ) for (node_id, _) in worker_node_and_gpu_ids]
  128. self._run_workers("update_environment_variables",
  129. all_args=all_args_to_update_environment_variables)
  130. distributed_init_method = get_distributed_init_method(
  131. driver_ip, get_open_port())
  132. # Initialize the actual workers inside worker wrapper.
  133. init_worker_all_kwargs = [
  134. self._get_worker_kwargs(
  135. local_rank=node_workers[node_id].index(rank),
  136. rank=rank,
  137. distributed_init_method=distributed_init_method,
  138. ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
  139. ]
  140. self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
  141. self._run_workers("init_device")
  142. self._run_workers("load_model",
  143. max_concurrent_workers=self.parallel_config.
  144. max_parallel_loading_workers)
  145. def execute_model(
  146. self,
  147. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  148. all_outputs = self._run_workers(
  149. "execute_model",
  150. driver_kwargs={"execute_model_req": execute_model_req},
  151. use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
  152. # Only the driver worker returns the sampling results.
  153. return all_outputs[0]
  154. def _run_workers(
  155. self,
  156. method: str,
  157. *args,
  158. driver_args: Optional[Tuple[Any, ...]] = None,
  159. driver_kwargs: Optional[Dict[str, Any]] = None,
  160. all_args: Optional[List[Tuple[Any, ...]]] = None,
  161. all_kwargs: Optional[List[Dict[str, Any]]] = None,
  162. use_dummy_driver: bool = False,
  163. max_concurrent_workers: Optional[int] = None,
  164. use_ray_compiled_dag: bool = False,
  165. **kwargs,
  166. ) -> Any:
  167. """Runs the given method on all workers. Can be used in the following
  168. ways:
  169. - args/kwargs: All workers share the same args/kwargs
  170. - args/kwargs and driver_args/driver_kwargs: Driver worker has
  171. different args
  172. - all_args/all_kwargs: args/kwargs for each worker are specified
  173. individually
  174. """
  175. if max_concurrent_workers:
  176. raise NotImplementedError(
  177. "max_concurrent_workers is not supported yet.")
  178. if driver_args is None:
  179. driver_args = args if all_args is None else all_args[0]
  180. if driver_kwargs is None:
  181. driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
  182. count = len(self.workers)
  183. all_worker_args = repeat(args, count) if all_args is None \
  184. else islice(all_args, 1, None)
  185. all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
  186. else islice(all_kwargs, 1, None)
  187. if use_ray_compiled_dag:
  188. # Right now, compiled DAG can only accept a single
  189. # input. TODO: Fix it.
  190. assert self.forward_dag is not None
  191. output_channels = self.forward_dag.execute(1)
  192. else:
  193. # Start the ray workers first.
  194. ray_worker_outputs = [
  195. worker.execute_method.remote(method, *worker_args,
  196. **worker_kwargs)
  197. for (worker, worker_args, worker_kwargs
  198. ) in zip(self.workers, all_worker_args, all_worker_kwargs)
  199. ]
  200. # Start the driver worker after all the ray workers.
  201. if not use_dummy_driver:
  202. driver_worker_output = self.driver_worker.execute_method(
  203. method, *driver_args, **driver_kwargs)
  204. else:
  205. driver_worker_output = ray.get(
  206. self.driver_dummy_worker.execute_method.remote(
  207. method, *driver_args, **driver_kwargs))
  208. # Get the results of the ray workers.
  209. if self.workers:
  210. if use_ray_compiled_dag:
  211. try:
  212. ray_worker_outputs = [
  213. pickle.loads(chan.begin_read())
  214. for chan in output_channels
  215. ]
  216. finally:
  217. # Has to call end_read in order to reuse the DAG.
  218. for chan in output_channels:
  219. chan.end_read()
  220. else:
  221. ray_worker_outputs = ray.get(ray_worker_outputs)
  222. return [driver_worker_output] + ray_worker_outputs
  223. def _compiled_ray_dag(self):
  224. import pkg_resources
  225. required_version = "2.9"
  226. current_version = pkg_resources.get_distribution("ray").version
  227. if current_version < required_version:
  228. raise ValueError(f"Ray version {required_version} or greater is "
  229. f"required, but found {current_version}")
  230. from ray.dag import InputNode, MultiOutputNode
  231. assert self.parallel_config.distributed_executor_backend == "ray"
  232. # Right now, compiled DAG requires at least 1 arg. We send
  233. # a dummy value for now. It will be fixed soon.
  234. with InputNode() as input_data:
  235. forward_dag = MultiOutputNode([
  236. worker.execute_model_compiled_dag_remote.bind(input_data)
  237. for worker in self.workers
  238. ])
  239. return forward_dag.experimental_compile()
  240. def check_health(self) -> None:
  241. """Raises an error if engine is unhealthy."""
  242. self._check_if_any_actor_is_dead()
  243. def _check_if_any_actor_is_dead(self):
  244. if not self.workers:
  245. return
  246. dead_actors = []
  247. for actor in self.workers:
  248. actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
  249. if actor_state["State"] == "DEAD":
  250. dead_actors.append(actor)
  251. if dead_actors:
  252. raise RuntimeError("At least one Worker is dead. "
  253. f"Dead Workers: {dead_actors}. ")
  254. class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
  255. def __init__(self, *args, **kwargs):
  256. super().__init__(*args, **kwargs)
  257. self.driver_executor = make_async(self.driver_worker.execute_method)
  258. async def _run_workers_async(
  259. self,
  260. method: str,
  261. *args,
  262. driver_args: Optional[Tuple[Any, ...]] = None,
  263. driver_kwargs: Optional[Dict[str, Any]] = None,
  264. **kwargs,
  265. ) -> Any:
  266. """Runs the given method on all workers."""
  267. coros = []
  268. if driver_args is None:
  269. driver_args = args
  270. if driver_kwargs is None:
  271. driver_kwargs = kwargs
  272. coros.append(
  273. self.driver_executor(method, *driver_args, **driver_kwargs))
  274. # Run the ray workers asynchronously.
  275. for worker in self.workers:
  276. coros.append(worker.execute_method.remote(method, *args, **kwargs))
  277. all_outputs = await asyncio.gather(*coros)
  278. return all_outputs