ray_gpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (_run_task_with_lock,
                                    error_on_invalid_device_count_status,
                                    get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
APHRODITE_USE_RAY_COMPILED_DAG = bool(
    os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
APHRODITE_TRACE_FUNCTION = int(os.getenv("APHRODITE_TRACE_FUNCTION", 0))
APHRODITE_USE_RAY_SPMD_WORKER = bool(
    os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))
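
# NOTE: bool(os.getenv(...)) is truthy for any non-empty string, so the two
# boolean flags above are treated as enabled whenever the variable is set to
# any non-empty value (including "0"); they are disabled only when the
# variable is unset or empty.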


class RayGPUExecutor(DistributedGPUExecutor):

    uses_ray: bool = True

    def _init_executor(self) -> None:
        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = APHRODITE_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "APHRODITE_USE_RAY_COMPILED_DAG=1 requires "
                "APHRODITE_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "APHRODITE_USE_RAY_SPMD_WORKER=1 requires "
                "APHRODITE_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        if self.speculative_config is not None:
            worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"
        else:
            worker_module_name = "aphrodite.task_handler.worker"
            worker_class_name = "Worker"

        return dict(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        worker_wrapper_kwargs = self._get_worker_wrapper_args()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

            if self.use_ray_spmd_worker:
                self.workers.append(worker)
            else:
                worker_ip = ray.get(worker.get_node_ip.remote())
                if worker_ip == driver_ip and self.driver_dummy_worker is None:
                    # If the worker is on the same node as the driver, we use
                    # it as the resource holder for the driver process.
                    self.driver_dummy_worker = worker
                    self.driver_worker = RayWorkerWrapper(
                        **worker_wrapper_kwargs)
                else:
                    # Else, added to the list of workers.
                    self.workers.append(worker)

        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        # the order in `worker_node_and_gpu_ids` does not necessarily match
        # the machine boundaries. We need to make sure that workers in the
        # same node are assigned consecutive ranks.
        # examples:
        # [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])]  # noqa

        # initialize worker ranks with -1 (unassigned)
        worker_ranks = [-1 for x in worker_node_and_gpu_ids]
        current_rank = 0
        while -1 in worker_ranks:
            # whenever we find an unassigned worker, find the node
            index = worker_ranks.index(-1)
            current_node_id = worker_node_and_gpu_ids[index][0]
            # assign ranks to all workers in the same node
            for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
                if node_id == current_node_id:
                    worker_ranks[i] = current_rank
                    current_rank += 1
        # with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3]

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for worker_rank, (node_id, gpu_ids) in zip(worker_ranks,
                                                   worker_node_and_gpu_ids):
            node_workers[node_id].append(worker_rank)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        error_on_invalid_device_count_status()

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id,
                         _) in zip(worker_ranks, worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for rank, worker in sorted(zip(worker_ranks[1:], self.workers)):
            # We need to skip the driver worker, which we
            # do by skipping worker_ranks[0] which is always 0.
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)
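
    # NOTE: In SPMD mode the compiled Ray DAG is built lazily on the first
    # call below and cached in self.forward_dag, so subsequent calls reuse
    # the same compiled graph.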

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        outputs = ray.get(self.forward_dag.execute(execute_model_req))
        return outputs[0]

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - async_run_tensor_parallel_workers_only: If True the method will be
          run only in the remote TP workers, not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers) if not \
            async_run_tensor_parallel_workers_only \
            else len(self.non_driver_workers)
        # If using SPMD worker, all workers are the same, so we should execute
        # the args on all workers. Otherwise, we skip the first worker's args
        # because those args will go to the driver worker.
        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, first_worker_args_index, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, first_worker_args_index, None)

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            driver_args = args if all_args is None else all_args[0]
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

            # Start the driver worker after all the ray workers.
            if not use_dummy_driver:
                driver_worker_output = [
                    self.driver_worker.execute_method(method, *driver_args,
                                                      **driver_kwargs)
                ]
            else:
                assert self.driver_dummy_worker is not None
                driver_worker_output = [
                    ray.get(
                        self.driver_dummy_worker.execute_method.remote(
                            method, *driver_args, **driver_kwargs))
                ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)
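
    # NOTE: The compiled DAG below binds `execute_model_spmd` on every Ray
    # worker to a single InputNode, so one `forward_dag.execute(...)` call
    # fans the request out to all SPMD workers and the MultiOutputNode
    # returns their outputs as a list.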

    def _compiled_ray_dag(self, enable_asyncio: bool):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.32")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_spmd.bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

    def __del__(self):
        if self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
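
# The async executor below wraps the driver worker's `execute_method` with
# `make_async` and holds one asyncio.Lock per pipeline-parallel stage so that
# multiple virtual engines cannot execute on the same stage at the same time.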


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        dag_future = await self.forward_dag.execute_async(execute_model_req)
        outputs = await dag_future
        return outputs[0]

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the
            # constructor which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for APHRODITE_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def __del__(self):
        if self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)