# ray_gpu_executor.py

import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, use Ray's compiled DAG API, which reduces the
# control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
APHRODITE_TRACE_FUNCTION = int(os.getenv("APHRODITE_TRACE_FUNCTION", 0))
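# NOTE: `os.getenv` returns a string whenever the variable is set, and
# `bool()` of any non-empty string is True, so setting
# APHRODITE_USE_RAY_COMPILED_DAG to any non-empty value (including "0")
# enables the compiled-DAG path; leaving it unset keeps the default of False.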


class RayGPUExecutor(DistributedGPUExecutor):

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()
            self.extra_execute_model_run_workers_kwargs[
                "use_ray_compiled_dag"] = True
    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs
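    # Illustrative sketch (comment only, not executed): after the call above,
    # and assuming no other runtime_env keys were passed in, the remote actor
    # options contain a runtime env along these lines:
    #
    #   ray_remote_kwargs == {
    #       "runtime_env": {
    #           "nsight": {
    #               "t": "cuda,cudnn,cublas",
    #               "o": "'worker_process_%p'",
    #               "cuda-graph-trace": "node",
    #           }
    #       }
    #   }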
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if self.speculative_config is not None:
                worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
                worker_class_name = "create_spec_worker"
            else:
                worker_module_name = "aphrodite.task_handler.worker"
                worker_class_name = "Worker"

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Otherwise, add it to the list of remote workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # Convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # so string sorting is not sufficient.
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)
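        # Illustrative sketch (comment only, not executed): for a hypothetical
        # node holding GPUs 0 and 1, each worker on that node would receive an
        # argument tuple along these lines:
        #
        #   ({"CUDA_VISIBLE_DEVICES": "0,1",
        #     "APHRODITE_INSTANCE_ID": APHRODITE_INSTANCE_ID,
        #     "APHRODITE_TRACE_FUNCTION": str(APHRODITE_TRACE_FUNCTION)}, )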
        if len(node_gpus) == 1:
            # In the single-node case, we don't need an externally reachable
            # IP address; the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # and they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside the worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcast to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                rank = (pp_rank *
                        self.parallel_config.tensor_parallel_size) + tp_rank
                if rank == 0:
                    pass
                elif rank % self.parallel_config.tensor_parallel_size == 0:
                    self.tp_driver_workers.append(self.workers[rank - 1])
                else:
                    self.non_driver_workers.append(self.workers[rank - 1])
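        # Worked example (comment only): with pipeline_parallel_size=2 and
        # tensor_parallel_size=2, global ranks follow
        # rank = pp_rank * tp_size + tp_rank, i.e. 0..3. Rank 0 is the local
        # driver worker, rank 2 (the first rank of the second TP group) goes
        # to `tp_driver_workers`, and ranks 1 and 3 go to
        # `non_driver_workers`. `self.workers[rank - 1]` is used because
        # `self.workers` excludes the rank-0 driver.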
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - async_run_tensor_parallel_workers_only: If True the method will be
          run only in the remote TP workers, not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers) if not \
            async_run_tensor_parallel_workers_only \
            else len(self.non_driver_workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO: Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
            ray_worker_outputs = []
        else:
            # Start the ray workers first.
            ray_workers = self.workers
            if async_run_tensor_parallel_workers_only:
                ray_workers = self.non_driver_workers
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
            ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs
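    # Usage sketch (comments only), mirroring calls made elsewhere in this
    # file: broadcast the same call to every worker with shared args/kwargs,
    # give each worker its own kwargs via `all_kwargs`, or route the driver's
    # share of the call through the dummy actor with `use_dummy_driver`:
    #
    #   self._run_workers("init_device")
    #   self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    #   self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True)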
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        ray.get(parallel_worker_tasks)
    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions rather than raw strings so that e.g.
        # "2.10" is correctly treated as newer than "2.9".
        if pkg_resources.parse_version(current_version) < \
                pkg_resources.parse_version(required_version):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.distributed_executor_backend == "ray"

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.
                bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile()


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:

        async def _run_task_with_lock(task, lock, *args, **kwargs):
            async with lock:
                return await task(*args, **kwargs)

        tasks = []
        tasks.append(
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req)))
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]
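    # NOTE: `self.pp_locks` (one asyncio lock per pipeline stage) is not
    # defined in this file; it is assumed to be provided by
    # DistributedGPUExecutorAsync, so that calls to each stage's driver are
    # serialized while different stages can still run concurrently.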
    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)