ray_gpu_executor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import asyncio
  2. import copy
  3. from collections import defaultdict
  4. import os
  5. import pickle
  6. from typing import TYPE_CHECKING, Any, Dict, List, Optional
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
  9. ParallelConfig, SchedulerConfig,
  10. LoRAConfig, VisionLanguageConfig,
  11. SpeculativeConfig)
  12. from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
  13. from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
  14. from aphrodite.lora.request import LoRARequest
  15. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  16. from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
  17. get_open_port, get_distributed_init_method,
  18. make_async)
  19. if ray is not None:
  20. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  21. if TYPE_CHECKING:
  22. from ray.util.placement_group import PlacementGroup
  23. # If the env var is set, it uses the Ray's compiled DAG API
  24. # which optimizes the control plane overhead.
  25. # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
  26. USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
  27. class RayGPUExecutor(ExecutorBase):
  28. def __init__(
  29. self,
  30. model_config: ModelConfig,
  31. cache_config: CacheConfig,
  32. parallel_config: ParallelConfig,
  33. scheduler_config: SchedulerConfig,
  34. device_config: DeviceConfig,
  35. lora_config: Optional[LoRAConfig],
  36. vision_language_config: Optional[VisionLanguageConfig],
  37. speculative_config: Optional[SpeculativeConfig],
  38. ) -> None:
  39. self.model_config = model_config
  40. self.cache_config = cache_config
  41. self.lora_config = lora_config
  42. self.parallel_config = parallel_config
  43. self.scheduler_config = scheduler_config
  44. self.device_config = device_config
  45. self.vision_language_config = vision_language_config
  46. assert (not speculative_config
  47. ), "Speculative decoding not yet supported for RayGPU backend."
  48. assert self.parallel_config.worker_use_ray
  49. placement_group = self.parallel_config.placement_group
  50. # Disable Ray usage stats collection.
  51. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  52. if ray_usage != "1":
  53. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  54. # Create the parallel GPU workers.
  55. self._init_workers_ray(placement_group)
  56. self.forward_dag = None
  57. if USE_RAY_COMPILED_DAG:
  58. self.forward_dag = self._compiled_ray_dag()
  59. def _init_workers_ray(self, placement_group: "PlacementGroup",
  60. **ray_remote_kwargs):
  61. if self.parallel_config.tensor_parallel_size == 1:
  62. # For single GPU case, we use a ray worker with constrained memory.
  63. num_gpus = self.cache_config.gpu_memory_utilization
  64. else:
  65. # Otherwise, the ray workers are allocated with a full GPU.
  66. num_gpus = 1
  67. # The driver dummy worker does not actually use any resources.
  68. # It holds the resource for the driver worker.
  69. self.driver_dummy_worker: RayWorkerAphrodite = None
  70. # The remaining workers are the actual ray actors.
  71. self.workers: List[RayWorkerAphrodite] = []
  72. # Create the workers.
  73. driver_ip = get_ip()
  74. for bundle_id, bundle in enumerate(placement_group.bundle_specs):
  75. if not bundle.get("GPU", 0):
  76. continue
  77. scheduling_strategy = PlacementGroupSchedulingStrategy(
  78. placement_group=placement_group,
  79. placement_group_capture_child_tasks=True,
  80. placement_group_bundle_index=bundle_id,
  81. )
  82. worker = ray.remote(
  83. num_cpus=0,
  84. num_gpus=num_gpus,
  85. scheduling_strategy=scheduling_strategy,
  86. **ray_remote_kwargs,
  87. )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
  88. worker_ip = ray.get(worker.get_node_ip.remote())
  89. if worker_ip == driver_ip and self.driver_dummy_worker is None:
  90. # If the worker is on the same node as the driver, we use it
  91. # as the resource holder for the driver process.
  92. self.driver_dummy_worker = worker
  93. else:
  94. # Else, added to the list of workers.
  95. self.workers.append(worker)
  96. if self.driver_dummy_worker is None:
  97. raise ValueError(
  98. "Ray does not allocate any GPUs on the driver node. Consider "
  99. "adjusting the Ray placement group or running the driver on a "
  100. "GPU node.")
  101. # Get the set of GPU IDs used on each node.
  102. driver_node_id, driver_gpu_ids = ray.get(
  103. self.driver_dummy_worker.get_node_and_gpu_ids.remote())
  104. worker_node_and_gpu_ids = ray.get(
  105. [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
  106. node_workers = defaultdict(list)
  107. node_gpus = defaultdict(list)
  108. node_workers[driver_node_id].append(0)
  109. node_gpus[driver_node_id].extend(driver_gpu_ids)
  110. for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
  111. start=1):
  112. node_workers[node_id].append(i)
  113. node_gpus[node_id].extend(gpu_ids)
  114. for node_id, gpu_ids in node_gpus.items():
  115. node_gpus[node_id] = sorted(gpu_ids)
  116. # Set CUDA_VISIBLE_DEVICES for the driver and workers.
  117. set_cuda_visible_devices(node_gpus[driver_node_id])
  118. for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
  119. worker.set_cuda_visible_devices.remote(node_gpus[node_id])
  120. distributed_init_method = get_distributed_init_method(
  121. driver_ip, get_open_port())
  122. # Lazy import the Worker to avoid importing torch.cuda/xformers
  123. # before CUDA_VISIBLE_DEVICES is set in the Worker
  124. from aphrodite.task_handler.worker import Worker
  125. model_config = copy.deepcopy(self.model_config)
  126. parallel_config = copy.deepcopy(self.parallel_config)
  127. scheduler_config = copy.deepcopy(self.scheduler_config)
  128. device_config = copy.deepcopy(self.device_config)
  129. lora_config = copy.deepcopy(self.lora_config)
  130. cache_config = copy.deepcopy(self.cache_config)
  131. vision_language_config = copy.deepcopy(self.vision_language_config)
  132. # Initialize the actual workers with the Worker class.
  133. for rank, (worker, (node_id, _)) in enumerate(
  134. zip(self.workers, worker_node_and_gpu_ids),
  135. start=1,
  136. ):
  137. local_rank = node_workers[node_id].index(rank)
  138. worker.init_worker.remote(
  139. lambda rank=rank, local_rank=local_rank: Worker(
  140. model_config=model_config,
  141. parallel_config=parallel_config,
  142. scheduler_config=scheduler_config,
  143. device_config=device_config,
  144. cache_config=cache_config,
  145. local_rank=local_rank,
  146. rank=rank,
  147. distributed_init_method=distributed_init_method,
  148. lora_config=lora_config,
  149. vision_language_config=vision_language_config,
  150. ))
  151. # Initialize the driver worker with the Worker class.
  152. driver_rank = 0
  153. driver_local_rank = node_workers[driver_node_id].index(driver_rank)
  154. self.driver_worker = Worker(
  155. model_config=self.model_config,
  156. parallel_config=self.parallel_config,
  157. scheduler_config=self.scheduler_config,
  158. device_config=self.device_config,
  159. cache_config=self.cache_config,
  160. local_rank=driver_local_rank,
  161. rank=driver_rank,
  162. distributed_init_method=distributed_init_method,
  163. lora_config=self.lora_config,
  164. vision_language_config=self.vision_language_config,
  165. is_driver_worker=True,
  166. )
  167. self._run_workers("init_device")
  168. self._run_workers(
  169. "load_model",
  170. max_concurrent_workers=self.parallel_config.
  171. max_parallel_loading_workers,
  172. )
  173. def determine_num_available_blocks(self) -> tuple[int, int]:
  174. """Determine the number of available KV blocks.
  175. This invokes `determine_num_available_blocks` on each worker and takes
  176. the min of the results, guaranteeing that the selected cache sizes are
  177. compatible with all workers.
  178. Returns:
  179. - tuple[num_gpu_blocks, num_cpu_blocks]
  180. """
  181. # Get the maximum number of blocks that can be allocated on GPU and CPU.
  182. num_blocks = self._run_workers("determine_num_available_blocks", )
  183. # Since we use a shared centralized controller, we take the minimum
  184. # number of blocks across all workers to make sure all the memory
  185. # operators can be applied to all workers.
  186. num_gpu_blocks = min(b[0] for b in num_blocks)
  187. num_cpu_blocks = min(b[1] for b in num_blocks)
  188. return num_gpu_blocks, num_cpu_blocks
  189. def initialize_cache(self, num_gpu_blocks: int,
  190. num_cpu_blocks: int) -> None:
  191. """Initialize the KV cache in all workers.
  192. """
  193. # NOTE: We log here to avoid multiple logs when number of workers is
  194. # greater than one. We could log in the engine, but not all executors
  195. # have GPUs.
  196. logger.info(f"# GPU blocks: {num_gpu_blocks}, "
  197. f"# CPU blocks: {num_cpu_blocks}")
  198. logger.info(
  199. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501
  200. )
  201. self.cache_config.num_gpu_blocks = num_gpu_blocks
  202. self.cache_config.num_cpu_blocks = num_cpu_blocks
  203. self._run_workers("initialize_cache",
  204. num_gpu_blocks=num_gpu_blocks,
  205. num_cpu_blocks=num_cpu_blocks)
  206. def execute_model(self,
  207. seq_group_metadata_list: List[SequenceGroupMetadata],
  208. blocks_to_swap_in: Dict[int, int],
  209. blocks_to_swap_out: Dict[int, int],
  210. blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
  211. all_outputs = self._run_workers(
  212. "execute_model",
  213. driver_kwargs={
  214. "seq_group_metadata_list": seq_group_metadata_list,
  215. "blocks_to_swap_in": blocks_to_swap_in,
  216. "blocks_to_swap_out": blocks_to_swap_out,
  217. "blocks_to_copy": blocks_to_copy,
  218. },
  219. use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
  220. # Only the driver worker returns the sampling results.
  221. output = all_outputs[0]
  222. return output
  223. def add_lora(self, lora_request: LoRARequest) -> bool:
  224. assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
  225. return self._run_workers(
  226. "add_lora",
  227. lora_request=lora_request,
  228. )
  229. def remove_lora(self, lora_id: int) -> bool:
  230. assert lora_id > 0, "lora_id must be greater than 0."
  231. return self._run_workers(
  232. "remove_lora",
  233. lora_id=lora_id,
  234. )
  235. def list_loras(self) -> List[int]:
  236. return self._run_workers("list_loras")
  237. def _run_workers(
  238. self,
  239. method: str,
  240. *args,
  241. driver_args: Optional[List[Any]] = None,
  242. driver_kwargs: Optional[Dict[str, Any]] = None,
  243. max_concurrent_workers: Optional[int] = None,
  244. use_ray_compiled_dag: bool = False,
  245. **kwargs,
  246. ) -> Any:
  247. """Runs the given method on all workers."""
  248. if max_concurrent_workers:
  249. raise NotImplementedError(
  250. "max_concurrent_workers is not supported yet.")
  251. if use_ray_compiled_dag:
  252. # Right now, compiled DAG can only accept a single
  253. # input. TODO: Fix it.
  254. output_channels = self.forward_dag.execute(1)
  255. else:
  256. # Start the ray workers first.
  257. ray_worker_outputs = [
  258. worker.execute_method.remote(method, *args, **kwargs)
  259. for worker in self.workers
  260. ]
  261. if driver_args is None:
  262. driver_args = args
  263. if driver_kwargs is None:
  264. driver_kwargs = kwargs
  265. # Start the driver worker after all the ray workers.
  266. driver_worker_output = getattr(self.driver_worker,
  267. method)(*driver_args, **driver_kwargs)
  268. # Get the results of the ray workers.
  269. if self.workers:
  270. if use_ray_compiled_dag:
  271. try:
  272. ray_worker_outputs = [
  273. pickle.loads(chan.begin_read())
  274. for chan in output_channels
  275. ]
  276. finally:
  277. # Has to call end_read in order to reuse the DAG.
  278. for chan in output_channels:
  279. chan.end_read()
  280. else:
  281. ray_worker_outputs = ray.get(ray_worker_outputs)
  282. return [driver_worker_output] + ray_worker_outputs
  283. def _compiled_ray_dag(self):
  284. import pkg_resources
  285. required_version = "2.9"
  286. current_version = pkg_resources.get_distribution("ray").version
  287. if current_version < required_version:
  288. raise ValueError(f"Ray version {required_version} or greater is "
  289. f"required, but found {current_version}")
  290. from ray.dag import MultiOutputNode, InputNode
  291. assert self.parallel_config.worker_use_ray
  292. # Right now, compiled DAG requires at least 1 arg. We send
  293. # a dummy value for now. It will be fixed soon.
  294. with InputNode() as input_data:
  295. forward_dag = MultiOutputNode([
  296. worker.execute_model_compiled_dag_remote.bind(input_data)
  297. for worker in self.workers
  298. ])
  299. return forward_dag.experimental_compile()
  300. def check_health(self) -> None:
  301. """Raises an error if engine is unhealthy."""
  302. self._check_if_any_actor_is_dead()
  303. def _check_if_any_actor_is_dead(self):
  304. if not self.workers:
  305. return
  306. dead_actors = []
  307. for actor in self.workers:
  308. actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
  309. if actor_state["State"] == "DEAD":
  310. dead_actors.append(actor)
  311. if dead_actors:
  312. raise RuntimeError("At least one Worker is dead. "
  313. f"Dead Workers: {dead_actors}. ")
  314. class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
  315. async def _run_workers_async(
  316. self,
  317. method: str,
  318. *args,
  319. driver_args: Optional[List[Any]] = None,
  320. driver_kwargs: Optional[Dict[str, Any]] = None,
  321. **kwargs,
  322. ) -> Any:
  323. """Runs the given method on all workers."""
  324. coros = []
  325. if driver_args is None:
  326. driver_args = args
  327. if driver_kwargs is None:
  328. driver_kwargs = kwargs
  329. # Run the driver worker asynchronously.
  330. driver_executor = make_async(getattr(self.driver_worker, method))
  331. coros.append(driver_executor(*driver_args, **driver_kwargs))
  332. # Run the ray workers asynchronously.
  333. for worker in self.workers:
  334. coros.append(worker.execute_method.remote(method, *args, **kwargs))
  335. all_outputs = await asyncio.gather(*coros)
  336. return all_outputs
  337. async def execute_model_async(
  338. self,
  339. seq_group_metadata_list: List[SequenceGroupMetadata],
  340. blocks_to_swap_in: Dict[int, int],
  341. blocks_to_swap_out: Dict[int, int],
  342. blocks_to_copy: Dict[int, List[int]],
  343. ) -> SamplerOutput:
  344. all_outputs = await self._run_workers_async(
  345. "execute_model",
  346. driver_kwargs={
  347. "seq_group_metadata_list": seq_group_metadata_list,
  348. "blocks_to_swap_in": blocks_to_swap_in,
  349. "blocks_to_swap_out": blocks_to_swap_out,
  350. "blocks_to_copy": blocks_to_copy,
  351. })
  352. # Only the driver worker returns the sampling results.
  353. output = all_outputs[0]
  354. return output
  355. async def check_health_async(self) -> None:
  356. """Raises an error if engine is unhealthy."""
  357. self._check_if_any_actor_is_dead()