1
0

ray_xpu_executor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import asyncio
  2. import os
  3. from collections import defaultdict
  4. from itertools import islice, repeat
  5. from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
  6. Tuple, Union)
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. PromptAdapterConfig, SchedulerConfig,
  11. SpeculativeConfig)
  12. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  13. from aphrodite.common.utils import (get_distributed_init_method, get_ip,
  14. get_open_port, make_async)
  15. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  16. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  17. from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
  18. from aphrodite.lora.request import LoRARequest
  19. if ray is not None:
  20. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  21. if TYPE_CHECKING:
  22. from ray.util.placement_group import PlacementGroup
  23. # If the env var is set, it uses the Ray's compiled DAG API
  24. # which optimizes the control plane overhead.
  25. # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
  26. USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
  27. class RayXPUExecutor(DistributedGPUExecutor):
  28. uses_ray: bool = True
  29. def __init__(
  30. self,
  31. model_config: ModelConfig,
  32. cache_config: CacheConfig,
  33. parallel_config: ParallelConfig,
  34. scheduler_config: SchedulerConfig,
  35. device_config: DeviceConfig,
  36. load_config: LoadConfig,
  37. lora_config: Optional[LoRAConfig],
  38. speculative_config: Optional[SpeculativeConfig],
  39. prompt_adapter_config: Optional[PromptAdapterConfig],
  40. ) -> None:
  41. assert device_config.device_type == "xpu"
  42. assert (not speculative_config
  43. ), "Speculative decoding not yet supported for XPU backend"
  44. self.model_config = model_config
  45. self.cache_config = cache_config
  46. self.load_config = load_config
  47. self.lora_config = lora_config
  48. self.parallel_config = parallel_config
  49. self.scheduler_config = scheduler_config
  50. self.device_config = device_config
  51. self.prompt_adapter_config = prompt_adapter_config
  52. placement_group = self.parallel_config.placement_group
  53. # Disable Ray usage stats collection.
  54. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  55. if ray_usage != "1":
  56. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  57. # Create the parallel GPU workers.
  58. self._init_workers_ray(placement_group)
  59. self.forward_dag = None
  60. if USE_RAY_COMPILED_DAG:
  61. self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
  62. # This is non-None when the execute model loop is running
  63. # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
  64. self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
  65. # Updated by implementations that require additional args to be passed
  66. # to the _run_workers execute_model call
  67. self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
  68. def _init_executor(self) -> None:
  69. pass
  70. def determine_num_available_blocks(self) -> Tuple[int, int]:
  71. """Determine the number of available KV blocks.
  72. This invokes `determine_num_available_blocks` on each worker and takes
  73. the min of the results, guaranteeing that the selected cache sizes are
  74. compatible with all workers.
  75. Returns:
  76. - Tuple[num_gpu_blocks, num_cpu_blocks]
  77. """
  78. # Get the maximum number of blocks that can be allocated on GPU and CPU.
  79. num_blocks = self._run_workers("determine_num_available_blocks", )
  80. # Since we use a shared centralized controller, we take the minimum
  81. # number of blocks across all workers to make sure all the memory
  82. # operators can be applied to all workers.
  83. num_gpu_blocks = min(b[0] for b in num_blocks)
  84. num_cpu_blocks = min(b[1] for b in num_blocks)
  85. return num_gpu_blocks, num_cpu_blocks
  86. def _get_worker_wrapper_args(self) -> Dict[str, Any]:
  87. return dict(
  88. worker_module_name="aphrodite.task_handler.xpu_worker",
  89. worker_class_name="XPUWorker",
  90. trust_remote_code=self.model_config.trust_remote_code,
  91. )
  92. def _init_workers_ray(self, placement_group: "PlacementGroup",
  93. **ray_remote_kwargs):
  94. if self.parallel_config.tensor_parallel_size == 1:
  95. # For single GPU case, we use a ray worker with constrained memory.
  96. num_gpus = self.cache_config.gpu_memory_utilization
  97. else:
  98. # Otherwise, the ray workers are allocated with a full GPU.
  99. num_gpus = 1
  100. # The driver dummy worker does not actually use any resources.
  101. # It holds the resource for the driver worker.
  102. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
  103. # The remaining workers are the actual ray actors.
  104. self.workers: List[RayWorkerWrapper] = []
  105. # Create the workers.
  106. driver_ip = get_ip()
  107. worker_wrapper_kwargs = self._get_worker_wrapper_args()
  108. for bundle_id, bundle in enumerate(placement_group.bundle_specs):
  109. if not bundle.get("GPU", 0):
  110. continue
  111. scheduling_strategy = PlacementGroupSchedulingStrategy(
  112. placement_group=placement_group,
  113. placement_group_capture_child_tasks=True,
  114. placement_group_bundle_index=bundle_id,
  115. )
  116. worker = ray.remote(
  117. num_cpus=0,
  118. num_gpus=num_gpus,
  119. scheduling_strategy=scheduling_strategy,
  120. **ray_remote_kwargs,
  121. )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
  122. worker_ip = ray.get(worker.get_node_ip.remote())
  123. if worker_ip == driver_ip and self.driver_dummy_worker is None:
  124. # If the worker is on the same node as the driver, we use it
  125. # as the resource holder for the driver process.
  126. self.driver_dummy_worker = worker
  127. self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
  128. else:
  129. # Else, added to the list of workers.
  130. self.workers.append(worker)
  131. if self.driver_dummy_worker is None:
  132. raise ValueError(
  133. "Ray does not allocate any GPUs on the driver node. Consider "
  134. "adjusting the Ray placement group or running the driver on a "
  135. "GPU node.")
  136. # Get the set of GPU IDs used on each node.
  137. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
  138. use_dummy_driver=True)
  139. node_workers = defaultdict(list)
  140. node_gpus = defaultdict(list)
  141. for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
  142. node_workers[node_id].append(i)
  143. node_gpus[node_id].extend(gpu_ids)
  144. for node_id, gpu_ids in node_gpus.items():
  145. node_gpus[node_id] = sorted(gpu_ids)
  146. # TODO: add env var for xpu
  147. distributed_init_method = get_distributed_init_method(
  148. driver_ip, get_open_port())
  149. def collect_arg_helper_func(**kwargs):
  150. # avoid writing `{"name": value}` manually
  151. return kwargs
  152. init_worker_all_kwargs = []
  153. # Initialize the actual workers inside worker wrapper.
  154. for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
  155. local_rank = node_workers[node_id].index(rank)
  156. init_worker_all_kwargs.append(
  157. collect_arg_helper_func(
  158. model_config=self.model_config,
  159. parallel_config=self.parallel_config,
  160. scheduler_config=self.scheduler_config,
  161. device_config=self.device_config,
  162. cache_config=self.cache_config,
  163. load_config=self.load_config,
  164. local_rank=local_rank,
  165. rank=rank,
  166. distributed_init_method=distributed_init_method,
  167. lora_config=self.lora_config,
  168. is_driver_worker=rank == 0,
  169. ))
  170. self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
  171. self._run_workers("init_device")
  172. self._run_workers(
  173. "load_model",
  174. max_concurrent_workers=self.parallel_config.
  175. max_parallel_loading_workers,
  176. )
  177. def initialize_cache(self, num_gpu_blocks: int,
  178. num_cpu_blocks: int) -> None:
  179. """Initialize the KV cache in all workers.
  180. """
  181. # NOTE: We log here to avoid multiple logs when number of workers is
  182. # greater than one. We could log in the engine, but not all executors
  183. # have GPUs.
  184. logger.info(f"# XPU blocks: {num_gpu_blocks}, "
  185. f"# CPU blocks: {num_cpu_blocks}")
  186. logger.info(
  187. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501
  188. )
  189. self.cache_config.num_gpu_blocks = num_gpu_blocks
  190. self.cache_config.num_cpu_blocks = num_cpu_blocks
  191. self._run_workers("initialize_cache",
  192. num_gpu_blocks=num_gpu_blocks,
  193. num_cpu_blocks=num_cpu_blocks)
  194. def _driver_execute_model(
  195. self,
  196. execute_model_req: Optional[ExecuteModelRequest] = None
  197. ) -> List[SamplerOutput]:
  198. """Run execute_model in the driver worker.
  199. Passing None will cause the driver to stop the model execution
  200. loop running in each of the remote workers.
  201. """
  202. return self.driver_worker.execute_method("execute_model",
  203. execute_model_req)
  204. def add_lora(self, lora_request: LoRARequest) -> bool:
  205. assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
  206. return self._run_workers(
  207. "add_lora",
  208. lora_request=lora_request,
  209. )
  210. def remove_lora(self, lora_id: int) -> bool:
  211. assert lora_id > 0, "lora_id must be greater than 0."
  212. return self._run_workers(
  213. "remove_lora",
  214. lora_id=lora_id,
  215. )
  216. def list_loras(self) -> Set[int]:
  217. return self._run_workers("list_loras")
  218. def pin_lora(self, lora_id: int) -> bool:
  219. assert lora_id > 0, "lora_id must be greater than 0."
  220. return self._run_workers(
  221. "pin_lora",
  222. lora_id=lora_id,
  223. )
  224. def _run_workers(
  225. self,
  226. method: str,
  227. *args,
  228. async_run_remote_workers_only: bool = False,
  229. all_args: Optional[List[Tuple[Any, ...]]] = None,
  230. all_kwargs: Optional[List[Dict[str, Any]]] = None,
  231. use_dummy_driver: bool = False,
  232. max_concurrent_workers: Optional[int] = None,
  233. **kwargs,
  234. ) -> Any:
  235. """Runs the given method on all workers. Can be used in the following
  236. ways:
  237. - args/kwargs: All workers share the same args/kwargs
  238. - args/kwargs and driver_args/driver_kwargs: Driver worker has
  239. different args
  240. - all_args/all_kwargs: args/kwargs for each worker are specified
  241. individually
  242. """
  243. if max_concurrent_workers:
  244. raise NotImplementedError(
  245. "max_concurrent_workers is not supported yet.")
  246. count = len(self.workers)
  247. all_worker_args = repeat(args, count) if all_args is None \
  248. else islice(all_args, 1, None)
  249. all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
  250. else islice(all_kwargs, 1, None)
  251. # Start the ray workers first.
  252. ray_worker_outputs = [
  253. worker.execute_method.remote(method, *worker_args, **worker_kwargs)
  254. for (worker, worker_args, worker_kwargs
  255. ) in zip(self.workers, all_worker_args, all_worker_kwargs)
  256. ]
  257. if async_run_remote_workers_only:
  258. # Just return futures
  259. return ray_worker_outputs
  260. driver_worker_output = []
  261. driver_args = args if all_args is None else all_args[0]
  262. driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
  263. # Start the driver worker after all the ray workers.
  264. if not use_dummy_driver:
  265. driver_worker_output = self.driver_worker.execute_method(
  266. method, *driver_args, **driver_kwargs)
  267. else:
  268. assert self.driver_dummy_worker is not None
  269. driver_worker_output = ray.get(
  270. self.driver_dummy_worker.execute_method.remote(
  271. method, *driver_args, **driver_kwargs))
  272. # Get the results of the ray workers.
  273. if self.workers:
  274. ray_worker_outputs = ray.get(ray_worker_outputs)
  275. return driver_worker_output + ray_worker_outputs
  276. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  277. """Wait for futures returned from _run_workers() with
  278. async_run_remote_workers_only to complete."""
  279. ray.get(parallel_worker_tasks)
  280. def _compiled_ray_dag(self, enable_asyncio: bool):
  281. import pkg_resources
  282. from packaging import version
  283. required_version = version.parse("2.32")
  284. current_version = version.parse(
  285. pkg_resources.get_distribution("ray").version)
  286. if current_version < required_version:
  287. raise ValueError(f"Ray version {required_version} or greater is "
  288. f"required, but found {current_version}")
  289. from ray.dag import InputNode, MultiOutputNode
  290. assert self.parallel_config.use_ray
  291. # Right now, compiled DAG requires at least 1 arg. We send
  292. # a dummy value for now. It will be fixed soon.
  293. with InputNode() as input_data:
  294. forward_dag = MultiOutputNode([
  295. worker.execute_model_compiled_dag_remote.
  296. bind( # type: ignore[attr-defined]
  297. input_data) for worker in self.workers
  298. ])
  299. return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
  300. def check_health(self) -> None:
  301. """Raises an error if engine is unhealthy."""
  302. self._check_if_any_actor_is_dead()
  303. def _check_if_any_actor_is_dead(self):
  304. if not self.workers:
  305. return
  306. dead_actors = []
  307. for actor in self.workers:
  308. actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
  309. if actor_state["State"] == "DEAD":
  310. dead_actors.append(actor)
  311. if dead_actors:
  312. raise RuntimeError("At least one Worker is dead. "
  313. f"Dead Workers: {dead_actors}. ")
  314. class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):
  315. def __init__(self, *args, **kwargs):
  316. super().__init__(*args, **kwargs)
  317. self.driver_exec_method = make_async(self.driver_worker.execute_method)
  318. async def _driver_execute_model_async(
  319. self,
  320. execute_model_req: Optional[ExecuteModelRequest] = None
  321. ) -> List[SamplerOutput]:
  322. return await self.driver_exec_method("execute_model",
  323. execute_model_req)
  324. async def _start_worker_execution_loop(self):
  325. coros = [
  326. worker.execute_method.remote("start_worker_execution_loop")
  327. for worker in self.workers
  328. ]
  329. return await asyncio.gather(*coros)