ray_gpu_executor.py

import asyncio
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

from loguru import logger

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.engine.ray_tools import RayWorkerWrapper, ray
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, Aphrodite uses Ray's compiled DAG API,
# which optimizes the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
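# NOTE: bool() of any non-empty string is True, so setting the variable to
# "0" or "false" still enables the compiled DAG path; only leaving it unset
# disables it.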


class RayGPUExecutor(ExecutorBase):

    def _init_executor(self) -> None:
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."
        assert self.parallel_config.worker_use_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="aphrodite.task_handler.worker",
                worker_class_name="Worker",
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="aphrodite.task_handler.worker",
                    worker_class_name="Worker",
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)
        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
        all_args_to_update_environment_variables = []
        for (node_id, _) in worker_node_and_gpu_ids:
            all_args_to_update_environment_variables.append([{
                "CUDA_VISIBLE_DEVICES":
                ",".join(map(str, node_gpus[node_id])),
            }])
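
        # NOTE: every worker on a node receives the full, sorted list of that
        # node's GPUs in CUDA_VISIBLE_DEVICES; each worker is then expected to
        # select its own device via the local_rank computed below.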
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        def collect_arg_helper_func(**kwargs):
            # Simply returns its keyword arguments as a dict; used to build
            # the per-rank kwargs for worker initialization.
            return kwargs

        init_worker_all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            init_worker_all_kwargs.append(
                collect_arg_helper_func(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    device_config=self.device_config,
                    cache_config=self.cache_config,
                    local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    lora_config=self.lora_config,
                    vision_language_config=self.vision_language_config,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - Tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""

        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]],
                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
        all_outputs = self._run_workers(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            },
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output
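
    # NOTE: _run_workers returns a list with one entry per worker (driver
    # first), so the LoRA methods below actually return a list of per-worker
    # results rather than the single value their annotations suggest.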
    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        all_args: Optional[List[List[Any]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # For MyPy type checking.
        assert driver_args is not None
        assert driver_kwargs is not None

        if all_args is None:
            all_args = [driver_args] + [args] * len(self.workers)
        if all_kwargs is None:
            all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
        assert all_args is not None
        assert all_kwargs is not None
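
        # Layout convention: index 0 of all_args/all_kwargs belongs to the
        # driver worker; indices 1..N map one-to-one onto self.workers in
        # order, and the return value below preserves the same ordering.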
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO: Fix it.
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_args[1:], all_kwargs[1:])
            ]

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *all_args[0], **all_kwargs[0])
        else:
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *all_args[0], **all_kwargs[0]))

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions rather than raw strings so that e.g.
        # "2.10" is not treated as older than "2.9".
        if (pkg_resources.parse_version(current_version) <
                pkg_resources.parse_version(required_version)):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}.")


class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        coros = []

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs
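
        # NOTE: unlike the synchronous _run_workers, there is no dummy-driver
        # or compiled-DAG handling here; the driver call is wrapped with
        # make_async so it does not block the event loop while the Ray workers
        # run concurrently.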

        # Run the driver worker asynchronously.
        def helper():
            return self.driver_worker.execute_method(method, *driver_args,
                                                     **driver_kwargs)

        driver_executor = make_async(helper)
        coros.append(driver_executor())

        # Run the ray workers asynchronously.
        for worker in self.workers:
            coros.append(worker.execute_method.remote(method, *args, **kwargs))

        all_outputs = await asyncio.gather(*coros)
        return all_outputs

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int = 0,
    ) -> List[SamplerOutput]:
        all_outputs = await self._run_workers_async(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
                "num_lookahead_slots": num_lookahead_slots,
            })

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output