# ray_gpu_executor.py
import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, Ray's compiled DAG API is used, which reduces the
# control plane overhead. Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1
# to enable it. Note that bool(os.getenv(...)) treats any non-empty value as
# truthy, so setting the variable to "0" still enables the flag.
USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
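# Illustrative note (not from the original source): the constant above is
# evaluated at import time, so the variable has to be set before this module
# is imported, e.g.
#
#     os.environ["APHRODITE_USE_RAY_COMPILED_DAG"] = "1"
#     from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
#
# or exported in the shell that launches the server process.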
class RayGPUExecutor(DistributedGPUExecutor):

    def _init_executor(self) -> None:
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."
        assert self.parallel_config.worker_use_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()
    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs
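    # For reference (illustrative): after this call, ray_remote_kwargs carries
    # a runtime env along the lines of
    #
    #     {"runtime_env": {"nsight": {"t": "cuda,cudnn,cublas",
    #                                 "o": "'worker_process_%p'",
    #                                 "cuda-graph-trace": "node"}}}
    #
    # which Ray's nsight runtime-env plugin uses to launch each worker
    # process under Nsight Systems (nsys) profiling.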
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="aphrodite.task_handler.worker",
                worker_class_name="Worker",
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="aphrodite.task_handler.worker",
                    worker_class_name="Worker",
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            os.getenv("APHRODITE_TRACE_FUNCTION", "0"),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)
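    # Worked example (illustrative, values made up): with four workers spread
    # over two 2-GPU nodes, worker_node_and_gpu_ids could look like
    #
    #     [("node-A", [0]), ("node-A", [1]), ("node-B", [0]), ("node-B", [1])]
    #
    # so both node-A workers get CUDA_VISIBLE_DEVICES="0,1" and global ranks
    # 0 and 1, while the node-B workers get global ranks 2 and 3 but local
    # ranks 0 and 1 via node_workers["node-B"].index(rank).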
    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]],
                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
        all_outputs = self._run_workers(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
                "num_lookahead_slots": num_lookahead_slots,
            },
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

        # Only the driver worker returns the sampling results.
        return all_outputs[0]
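    # Illustrative call shape (values made up): the block maps mirror what
    # the scheduler produces, e.g. a single swap-in plus one block copy:
    #
    #     executor.execute_model(seq_group_metadata_list=metadata,
    #                            blocks_to_swap_in={3: 7},
    #                            blocks_to_swap_out={},
    #                            blocks_to_copy={5: [9]})
    #
    # Every worker runs the forward pass, but only the driver's list of
    # SamplerOutput is returned to the engine.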
    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - args/kwargs: all workers share the same args/kwargs.
        - args/kwargs and driver_args/driver_kwargs: the driver worker has
          different args.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if driver_args is None:
            driver_args = args if all_args is None else all_args[0]
        if driver_kwargs is None:
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO: Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
            ]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs
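    # Usage sketches (illustrative only), matching the three modes described
    # in the docstring above:
    #
    #     self._run_workers("init_device")
    #         # every worker, driver included, runs init_device() with no args
    #     self._run_workers("execute_model", driver_kwargs={...})
    #         # the driver gets its own kwargs; the ray workers get none
    #     self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    #         # element 0 goes to the driver, the rest to the ray workers in
    #         # order (hence the islice(..., 1, None) above)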
    def _compiled_ray_dag(self):
        import pkg_resources

        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions rather than raw strings so that e.g.
        # "2.10" is not considered older than "2.9".
        if pkg_resources.parse_version(
                current_version) < pkg_resources.parse_version(
                    required_version):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()
    def check_health(self) -> None:
        """Raises an error if the engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}.")
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_executor = make_async(self.driver_worker.execute_method)

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        coros = []

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Run the driver worker's method without blocking the event loop.
        coros.append(
            self.driver_executor(method, *driver_args, **driver_kwargs))

        # Run the ray workers asynchronously.
        for worker in self.workers:
            coros.append(worker.execute_method.remote(method, *args, **kwargs))

        all_outputs = await asyncio.gather(*coros)
        return all_outputs
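    # Usage sketch (illustrative): from an async context,
    #
    #     outputs = await executor._run_workers_async("execute_model",
    #                                                 driver_kwargs={...})
    #
    # awaits the wrapped driver call together with the Ray ObjectRefs and
    # returns [driver_output, *ray_worker_outputs], mirroring _run_workers.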