ray_xpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
                    Tuple, Union)

from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig,
                                     SpeculativeConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.lora.request import LoRARequest

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, use Ray's compiled DAG API,
# which reduces the control-plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = envs.APHRODITE_USE_RAY_COMPILED_DAG
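
# Illustrative invocation (the script name below is a placeholder): the flag is
# read once at import time, so the environment variable must be exported before
# the process that creates the executor starts, e.g.
#
#   APHRODITE_USE_RAY_COMPILED_DAG=1 python your_serving_script.py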


class RayXPUExecutor(DistributedGPUExecutor):

    uses_ray: bool = True

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
    ) -> None:
        assert device_config.device_type == "xpu"
        assert (not speculative_config
                ), "Speculative decoding not yet supported for XPU backend"

        self.model_config = model_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.prompt_adapter_config = prompt_adapter_config

        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

    def _init_executor(self) -> None:
        pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - Tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks
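
    # Worked example with hypothetical numbers: if three workers report
    # (8192, 1024), (7936, 1024), and (8064, 512) available blocks, the
    # element-wise min yields (7936, 512), so every worker can honor the
    # chosen cache size.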

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        return dict(
            worker_module_name="aphrodite.task_handler.xpu_worker",
            worker_class_name="XPUWorker",
            trust_remote_code=self.model_config.trust_remote_code,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For the single-GPU case, we use a Ray worker with constrained
            # memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the Ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual Ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        worker_wrapper_kwargs = self._get_worker_wrapper_args()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
            else:
                # Otherwise, add it to the list of remote workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # TODO: add env var for xpu
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        def collect_arg_helper_func(**kwargs):
            # Avoid writing `{"name": value}` dicts by hand.
            return kwargs

        init_worker_all_kwargs = []
        # Initialize the actual workers inside the worker wrapper.
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            init_worker_all_kwargs.append(
                collect_arg_helper_func(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    device_config=self.device_config,
                    cache_config=self.cache_config,
                    load_config=self.load_config,
                    local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    lora_config=self.lora_config,
                    is_driver_worker=rank == 0,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )
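
    # Minimal sketch (assuming a single-node, 2-way tensor-parallel run) of a
    # placement group compatible with the loop above, built with Ray's public
    # API; each GPU bundle becomes one worker, and the bundle co-located with
    # the driver is claimed by the dummy driver worker instead of a remote
    # actor:
    #
    #   import ray
    #   from ray.util.placement_group import placement_group
    #
    #   ray.init()
    #   pg = placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK")
    #   ray.get(pg.ready())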

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        logger.info(f"# XPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            f"Minimum concurrency: "
            f"{num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)
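
    # Worked example of the logged concurrency figure (hypothetical numbers):
    # with 8000 XPU blocks, a block size of 16 tokens, and max_model_len of
    # 4096, minimum concurrency is 8000 * 16 / 4096 = 31.25x, i.e. roughly 31
    # maximum-length sequences can hold KV cache at the same time.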

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)
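
    # Illustrative stop signal (the call site is an assumption): when the
    # remote workers are looping in start_worker_execution_loop, the loop is
    # ended by issuing
    #
    #   self._driver_execute_model(execute_model_req=None)
    #
    # from the driver, matching the docstring above.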

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        )

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - args/kwargs: all workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs are specified per worker; index 0
          is used for the driver worker.
        - async_run_remote_workers_only: run only the remote workers and
          return their futures without waiting on the results.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the Ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        if async_run_remote_workers_only:
            # Just return the futures.
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the Ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the Ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        # The driver's result comes first, followed by one result per remote
        # worker.
        return [driver_worker_output] + ray_worker_outputs
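
    # Illustrative calls (arguments are hypothetical) covering the calling
    # conventions documented above:
    #
    #   # Same kwargs for the driver and every remote worker:
    #   executor._run_workers("initialize_cache",
    #                         num_gpu_blocks=7936, num_cpu_blocks=512)
    #
    #   # Per-worker kwargs; index 0 goes to the driver worker:
    #   executor._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
    #
    #   # Remote workers only; returns Ray futures to be awaited later via
    #   # _wait_for_tasks_completion():
    #   futures = executor._run_workers("start_worker_execution_loop",
    #                                   async_run_remote_workers_only=True)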

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def _compiled_ray_dag(self, enable_asyncio: bool):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.32")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.
                bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
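
    # Sketch of how the compiled DAG is driven (data flow only; serialization
    # of the request and the exact channel handling are assumptions and depend
    # on the Ray version):
    #
    #   refs = self.forward_dag.execute(serialized_execute_model_req)
    #
    # Each worker bound through MultiOutputNode contributes one output.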

    def check_health(self) -> None:
        """Raises an error if the engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}. ")


class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)