ray_xpu_executor.py

import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
                    Tuple, Union)

from loguru import logger

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, SpeculativeConfig,
                                     VisionLanguageConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.lora.request import LoRARequest

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, Ray's compiled DAG API is used,
# which reduces the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
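# NOTE: `os.getenv` returns a string whenever the variable is set, and any
# non-empty string (including "0") is truthy, so the compiled DAG path is
# only disabled when the variable is left unset entirely.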


class RayXPUExecutor(DistributedGPUExecutor):

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        assert device_config.device_type == "xpu"
        assert (not speculative_config
                ), "Speculative decoding not yet supported for XPU backend"
        self.model_config = model_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config

        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

    def _init_executor(self) -> None:
        pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - Tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on each
        # worker's XPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks
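
    # Illustrative example (hypothetical numbers): if one worker reports
    # (7168, 2048) free (XPU, CPU) blocks and another reports (6912, 2048),
    # the cache is sized to the element-wise minimum, (6912, 2048), so that
    # every worker can accommodate it.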

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For the single-GPU case, we use a Ray worker with constrained
            # memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the Ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual Ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="aphrodite.task_handler.xpu_worker",
                worker_class_name="XPUWorker",
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="aphrodite.task_handler.xpu_worker",
                    worker_class_name="XPUWorker",
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # TODO: add an env var for XPU.
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())
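        # The resulting init method (typically a "tcp://<driver_ip>:<port>"
        # string) is handed to every worker below so that they all join the
        # same distributed environment.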

        def collect_arg_helper_func(**kwargs):
            # Avoid writing `{"name": value}` dicts by hand.
            return kwargs

        init_worker_all_kwargs = []
        # Initialize the actual workers inside the worker wrapper.
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            init_worker_all_kwargs.append(
                collect_arg_helper_func(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    device_config=self.device_config,
                    cache_config=self.cache_config,
                    load_config=self.load_config,
                    local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    lora_config=self.lora_config,
                    vision_language_config=self.vision_language_config,
                    is_driver_worker=rank == 0,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        logger.info(f"# XPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)
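
    # Worked example for the minimum-concurrency log above (hypothetical
    # numbers): with 4096 XPU blocks, a block size of 16 tokens, and a
    # max_model_len of 8192, the value is 4096 * 16 / 8192 = 8.00x, i.e.
    # eight maximum-length sequences fit in the KV cache at once.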

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        )

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. It can be used in the
        following ways:

        - args/kwargs: all workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually, with the first entry going to the driver worker.

        If async_run_remote_workers_only is True, only the Ray futures for
        the remote workers are returned and the driver worker is skipped.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        # The first entry of all_args/all_kwargs belongs to the driver worker;
        # the remote workers receive the rest.
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        if use_ray_compiled_dag:
            # Right now, the compiled DAG can only accept a single
            # input. TODO: Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the Ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
            ]

        if async_run_remote_workers_only:
            # Just return the futures.
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the Ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the Ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # end_read has to be called in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions rather than raw strings so that, e.g.,
        # "2.10" is treated as newer than "2.9".
        if pkg_resources.parse_version(current_version) < \
                pkg_resources.parse_version(required_version):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, the compiled DAG requires at least one arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.
                bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile()
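
    # The DAG compiled above is stored as `self.forward_dag` by `__init__`
    # and executed in `_run_workers` when `use_ray_compiled_dag` is set; its
    # output channels must be read and then released with `end_read()` before
    # the DAG can be executed again.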

    def check_health(self) -> None:
        """Raises an error if the engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}.")


class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
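

# Usage sketch (hypothetical, for orientation only): an async engine would
# construct `RayXPUExecutorAsync` with the config objects listed in
# `RayXPUExecutor.__init__`, start the remote workers' loop via
# `_start_worker_execution_loop()`, and step the driver with
# `await executor._driver_execute_model_async(execute_model_req)`; passing
# `None` to the driver call stops the remote execution loop, as noted in
# `_driver_execute_model`.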