ray_gpu_executor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. import asyncio
  2. import copy
  3. from collections import defaultdict
  4. import os
  5. import pickle
  6. from typing import TYPE_CHECKING, Any, Dict, List, Set, Optional
  7. from loguru import logger
  8. from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
  9. from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
  10. from aphrodite.lora.request import LoRARequest
  11. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  12. from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
  13. get_open_port, get_distributed_init_method,
  14. make_async)
  15. if ray is not None:
  16. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  17. if TYPE_CHECKING:
  18. from ray.util.placement_group import PlacementGroup
  19. # If the env var is set, it uses the Ray's compiled DAG API
  20. # which optimizes the control plane overhead.
  21. # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
  22. USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
  23. class RayGPUExecutor(ExecutorBase):
  24. def _init_executor(self) -> None:
  25. assert (not self.speculative_config
  26. ), "Speculative decoding not yet supported for RayGPU backend."
  27. assert self.parallel_config.worker_use_ray
  28. placement_group = self.parallel_config.placement_group
  29. # Disable Ray usage stats collection.
  30. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  31. if ray_usage != "1":
  32. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  33. # Create the parallel GPU workers.
  34. self._init_workers_ray(placement_group)
  35. self.forward_dag = None
  36. if USE_RAY_COMPILED_DAG:
  37. self.forward_dag = self._compiled_ray_dag()
  38. def _configure_ray_workers_use_nsight(self,
  39. ray_remote_kwargs) -> Dict[str, Any]:
  40. # If nsight profiling is enabled, we need to set the profiling
  41. # configuration for the ray workers as runtime env.
  42. runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
  43. runtime_env.update({
  44. "nsight": {
  45. "t": "cuda,cudnn,cublas",
  46. "o": "'worker_process_%p'",
  47. "cuda-graph-trace": "node",
  48. }
  49. })
  50. return ray_remote_kwargs
  51. def _init_workers_ray(self, placement_group: "PlacementGroup",
  52. **ray_remote_kwargs):
  53. if self.parallel_config.tensor_parallel_size == 1:
  54. # For single GPU case, we use a ray worker with constrained memory.
  55. num_gpus = self.cache_config.gpu_memory_utilization
  56. else:
  57. # Otherwise, the ray workers are allocated with a full GPU.
  58. num_gpus = 1
  59. # The driver dummy worker does not actually use any resources.
  60. # It holds the resource for the driver worker.
  61. self.driver_dummy_worker: RayWorkerAphrodite = None
  62. # The remaining workers are the actual ray actors.
  63. self.workers: List[RayWorkerAphrodite] = []
  64. if self.parallel_config.ray_workers_use_nsight:
  65. ray_remote_kwargs = self._configure_ray_workers_use_nsight(
  66. ray_remote_kwargs)
  67. # Create the workers.
  68. driver_ip = get_ip()
  69. for bundle_id, bundle in enumerate(placement_group.bundle_specs):
  70. if not bundle.get("GPU", 0):
  71. continue
  72. scheduling_strategy = PlacementGroupSchedulingStrategy(
  73. placement_group=placement_group,
  74. placement_group_capture_child_tasks=True,
  75. placement_group_bundle_index=bundle_id,
  76. )
  77. worker = ray.remote(
  78. num_cpus=0,
  79. num_gpus=num_gpus,
  80. scheduling_strategy=scheduling_strategy,
  81. **ray_remote_kwargs,
  82. )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
  83. worker_ip = ray.get(worker.get_node_ip.remote())
  84. if worker_ip == driver_ip and self.driver_dummy_worker is None:
  85. # If the worker is on the same node as the driver, we use it
  86. # as the resource holder for the driver process.
  87. self.driver_dummy_worker = worker
  88. else:
  89. # Else, added to the list of workers.
  90. self.workers.append(worker)
  91. if self.driver_dummy_worker is None:
  92. raise ValueError(
  93. "Ray does not allocate any GPUs on the driver node. Consider "
  94. "adjusting the Ray placement group or running the driver on a "
  95. "GPU node.")
  96. # Get the set of GPU IDs used on each node.
  97. driver_node_id, driver_gpu_ids = ray.get(
  98. self.driver_dummy_worker.get_node_and_gpu_ids.remote())
  99. worker_node_and_gpu_ids = ray.get(
  100. [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
  101. node_workers = defaultdict(list)
  102. node_gpus = defaultdict(list)
  103. node_workers[driver_node_id].append(0)
  104. node_gpus[driver_node_id].extend(driver_gpu_ids)
  105. for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
  106. start=1):
  107. node_workers[node_id].append(i)
  108. node_gpus[node_id].extend(gpu_ids)
  109. for node_id, gpu_ids in node_gpus.items():
  110. node_gpus[node_id] = sorted(gpu_ids)
  111. # Set CUDA_VISIBLE_DEVICES for the driver and workers.
  112. set_cuda_visible_devices(node_gpus[driver_node_id])
  113. for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
  114. worker.set_cuda_visible_devices.remote(node_gpus[node_id])
  115. distributed_init_method = get_distributed_init_method(
  116. driver_ip, get_open_port())
  117. # Lazy import the Worker to avoid importing torch.cuda/xformers
  118. # before CUDA_VISIBLE_DEVICES is set in the Worker
  119. from aphrodite.task_handler.worker import Worker
  120. model_config = copy.deepcopy(self.model_config)
  121. parallel_config = copy.deepcopy(self.parallel_config)
  122. scheduler_config = copy.deepcopy(self.scheduler_config)
  123. device_config = copy.deepcopy(self.device_config)
  124. lora_config = copy.deepcopy(self.lora_config)
  125. cache_config = copy.deepcopy(self.cache_config)
  126. vision_language_config = copy.deepcopy(self.vision_language_config)
  127. # Initialize the actual workers with the Worker class.
  128. for rank, (worker, (node_id, _)) in enumerate(
  129. zip(self.workers, worker_node_and_gpu_ids),
  130. start=1,
  131. ):
  132. local_rank = node_workers[node_id].index(rank)
  133. worker.init_worker.remote(
  134. lambda rank=rank, local_rank=local_rank: Worker(
  135. model_config=model_config,
  136. parallel_config=parallel_config,
  137. scheduler_config=scheduler_config,
  138. device_config=device_config,
  139. cache_config=cache_config,
  140. local_rank=local_rank,
  141. rank=rank,
  142. distributed_init_method=distributed_init_method,
  143. lora_config=lora_config,
  144. vision_language_config=vision_language_config,
  145. ))
  146. # Initialize the driver worker with the Worker class.
  147. driver_rank = 0
  148. driver_local_rank = node_workers[driver_node_id].index(driver_rank)
  149. self.driver_worker = Worker(
  150. model_config=self.model_config,
  151. parallel_config=self.parallel_config,
  152. scheduler_config=self.scheduler_config,
  153. device_config=self.device_config,
  154. cache_config=self.cache_config,
  155. local_rank=driver_local_rank,
  156. rank=driver_rank,
  157. distributed_init_method=distributed_init_method,
  158. lora_config=self.lora_config,
  159. vision_language_config=self.vision_language_config,
  160. is_driver_worker=True,
  161. )
  162. self._run_workers("init_device")
  163. self._run_workers(
  164. "load_model",
  165. max_concurrent_workers=self.parallel_config.
  166. max_parallel_loading_workers,
  167. )
  168. def determine_num_available_blocks(self) -> tuple[int, int]:
  169. """Determine the number of available KV blocks.
  170. This invokes `determine_num_available_blocks` on each worker and takes
  171. the min of the results, guaranteeing that the selected cache sizes are
  172. compatible with all workers.
  173. Returns:
  174. - tuple[num_gpu_blocks, num_cpu_blocks]
  175. """
  176. # Get the maximum number of blocks that can be allocated on GPU and CPU.
  177. num_blocks = self._run_workers("determine_num_available_blocks", )
  178. # Since we use a shared centralized controller, we take the minimum
  179. # number of blocks across all workers to make sure all the memory
  180. # operators can be applied to all workers.
  181. num_gpu_blocks = min(b[0] for b in num_blocks)
  182. num_cpu_blocks = min(b[1] for b in num_blocks)
  183. return num_gpu_blocks, num_cpu_blocks
  184. def initialize_cache(self, num_gpu_blocks: int,
  185. num_cpu_blocks: int) -> None:
  186. """Initialize the KV cache in all workers.
  187. """
  188. # NOTE: We log here to avoid multiple logs when number of workers is
  189. # greater than one. We could log in the engine, but not all executors
  190. # have GPUs.
  191. logger.info(f"# GPU blocks: {num_gpu_blocks}, "
  192. f"# CPU blocks: {num_cpu_blocks}")
  193. logger.info(
  194. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501
  195. )
  196. self.cache_config.num_gpu_blocks = num_gpu_blocks
  197. self.cache_config.num_cpu_blocks = num_cpu_blocks
  198. self._run_workers("initialize_cache",
  199. num_gpu_blocks=num_gpu_blocks,
  200. num_cpu_blocks=num_cpu_blocks)
  201. def execute_model(self,
  202. seq_group_metadata_list: List[SequenceGroupMetadata],
  203. blocks_to_swap_in: Dict[int, int],
  204. blocks_to_swap_out: Dict[int, int],
  205. blocks_to_copy: Dict[int, List[int]],
  206. num_lookahead_slots: int = 0) -> List[SamplerOutput]:
  207. all_outputs = self._run_workers(
  208. "execute_model",
  209. driver_kwargs={
  210. "seq_group_metadata_list": seq_group_metadata_list,
  211. "blocks_to_swap_in": blocks_to_swap_in,
  212. "blocks_to_swap_out": blocks_to_swap_out,
  213. "blocks_to_copy": blocks_to_copy,
  214. },
  215. use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
  216. # Only the driver worker returns the sampling results.
  217. output = all_outputs[0]
  218. return output
  219. def add_lora(self, lora_request: LoRARequest) -> bool:
  220. assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
  221. return self._run_workers(
  222. "add_lora",
  223. lora_request=lora_request,
  224. )
  225. def remove_lora(self, lora_id: int) -> bool:
  226. assert lora_id > 0, "lora_id must be greater than 0."
  227. return self._run_workers(
  228. "remove_lora",
  229. lora_id=lora_id,
  230. )
  231. def list_loras(self) -> Set[int]:
  232. return self._run_workers("list_loras")
  233. def _run_workers(
  234. self,
  235. method: str,
  236. *args,
  237. driver_args: Optional[List[Any]] = None,
  238. driver_kwargs: Optional[Dict[str, Any]] = None,
  239. max_concurrent_workers: Optional[int] = None,
  240. use_ray_compiled_dag: bool = False,
  241. **kwargs,
  242. ) -> Any:
  243. """Runs the given method on all workers."""
  244. if max_concurrent_workers:
  245. raise NotImplementedError(
  246. "max_concurrent_workers is not supported yet.")
  247. if use_ray_compiled_dag:
  248. # Right now, compiled DAG can only accept a single
  249. # input. TODO: Fix it.
  250. output_channels = self.forward_dag.execute(1)
  251. else:
  252. # Start the ray workers first.
  253. ray_worker_outputs = [
  254. worker.execute_method.remote(method, *args, **kwargs)
  255. for worker in self.workers
  256. ]
  257. if driver_args is None:
  258. driver_args = args
  259. if driver_kwargs is None:
  260. driver_kwargs = kwargs
  261. # Start the driver worker after all the ray workers.
  262. driver_worker_output = getattr(self.driver_worker,
  263. method)(*driver_args, **driver_kwargs)
  264. # Get the results of the ray workers.
  265. if self.workers:
  266. if use_ray_compiled_dag:
  267. try:
  268. ray_worker_outputs = [
  269. pickle.loads(chan.begin_read())
  270. for chan in output_channels
  271. ]
  272. finally:
  273. # Has to call end_read in order to reuse the DAG.
  274. for chan in output_channels:
  275. chan.end_read()
  276. else:
  277. ray_worker_outputs = ray.get(ray_worker_outputs)
  278. return [driver_worker_output] + ray_worker_outputs
  279. def _compiled_ray_dag(self):
  280. import pkg_resources
  281. required_version = "2.9"
  282. current_version = pkg_resources.get_distribution("ray").version
  283. if current_version < required_version:
  284. raise ValueError(f"Ray version {required_version} or greater is "
  285. f"required, but found {current_version}")
  286. from ray.dag import MultiOutputNode, InputNode
  287. assert self.parallel_config.worker_use_ray
  288. # Right now, compiled DAG requires at least 1 arg. We send
  289. # a dummy value for now. It will be fixed soon.
  290. with InputNode() as input_data:
  291. forward_dag = MultiOutputNode([
  292. worker.execute_model_compiled_dag_remote.bind(input_data)
  293. for worker in self.workers
  294. ])
  295. return forward_dag.experimental_compile()
  296. def check_health(self) -> None:
  297. """Raises an error if engine is unhealthy."""
  298. self._check_if_any_actor_is_dead()
  299. def _check_if_any_actor_is_dead(self):
  300. if not self.workers:
  301. return
  302. dead_actors = []
  303. for actor in self.workers:
  304. actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
  305. if actor_state["State"] == "DEAD":
  306. dead_actors.append(actor)
  307. if dead_actors:
  308. raise RuntimeError("At least one Worker is dead. "
  309. f"Dead Workers: {dead_actors}. ")
  310. class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
  311. async def _run_workers_async(
  312. self,
  313. method: str,
  314. *args,
  315. driver_args: Optional[List[Any]] = None,
  316. driver_kwargs: Optional[Dict[str, Any]] = None,
  317. **kwargs,
  318. ) -> Any:
  319. """Runs the given method on all workers."""
  320. coros = []
  321. if driver_args is None:
  322. driver_args = args
  323. if driver_kwargs is None:
  324. driver_kwargs = kwargs
  325. # Run the driver worker asynchronously.
  326. driver_executor = make_async(getattr(self.driver_worker, method))
  327. coros.append(driver_executor(*driver_args, **driver_kwargs))
  328. # Run the ray workers asynchronously.
  329. for worker in self.workers:
  330. coros.append(worker.execute_method.remote(method, *args, **kwargs))
  331. all_outputs = await asyncio.gather(*coros)
  332. return all_outputs
  333. async def execute_model_async(
  334. self,
  335. seq_group_metadata_list: List[SequenceGroupMetadata],
  336. blocks_to_swap_in: Dict[int, int],
  337. blocks_to_swap_out: Dict[int, int],
  338. blocks_to_copy: Dict[int, List[int]],
  339. num_lookahead_slots: int = 0,
  340. ) -> SamplerOutput:
  341. all_outputs = await self._run_workers_async(
  342. "execute_model",
  343. driver_kwargs={
  344. "seq_group_metadata_list": seq_group_metadata_list,
  345. "blocks_to_swap_in": blocks_to_swap_in,
  346. "blocks_to_swap_out": blocks_to_swap_out,
  347. "blocks_to_copy": blocks_to_copy,
  348. "num_lookahead_slots": num_lookahead_slots,
  349. })
  350. # Only the driver worker returns the sampling results.
  351. output = all_outputs[0]
  352. return output