# ray_gpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import msgspec
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (_run_task_with_lock,
                                    get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.msgspec_utils import encode_hook
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, it uses Ray's compiled DAG API,
# which optimizes the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
APHRODITE_USE_RAY_COMPILED_DAG = envs.APHRODITE_USE_RAY_COMPILED_DAG
APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION
APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = (
    envs.APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
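
# Illustrative note (not part of the executor logic): the Ray-specific modes
# above are selected purely via environment variables, so a typical way to
# enable the compiled-DAG/SPMD path is to export both flags before starting
# the engine, for example:
#
#   APHRODITE_USE_RAY_SPMD_WORKER=1 \
#   APHRODITE_USE_RAY_COMPILED_DAG=1 \
#   python <your Aphrodite entry point> ...
#
# The entry point is left generic here; the requirement that both flags are
# set together is enforced in RayGPUExecutor._init_executor below.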


class RayGPUExecutor(DistributedGPUExecutor):

    uses_ray: bool = True

    def _init_executor(self) -> None:
        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
        # If the env var is set, it uses Ray's compiled DAG API,
        # which optimizes the control plane overhead.
        # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = APHRODITE_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" and other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "APHRODITE_USE_RAY_COMPILED_DAG=1 requires "
                "APHRODITE_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "APHRODITE_USE_RAY_SPMD_WORKER=1 requires "
                "APHRODITE_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])

    def shutdown(self) -> None:
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs
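
    # Sketch (illustrative, not executed): after the update above, the kwargs
    # later splatted into ``ray.remote(...)`` in _init_workers_ray carry a
    # runtime_env such as
    #
    #   {"runtime_env": {"nsight": {"t": "cuda,cudnn,cublas",
    #                               "o": "'worker_process_%p'",
    #                               "cuda-graph-trace": "node"}}}
    #
    # i.e. each Ray actor is launched under nsys with those profiling options
    # when ``parallel_config.ray_workers_use_nsight`` is enabled.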

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        (worker_module_name, worker_class_name,
         worker_class_fn) = self._get_worker_module_and_class()

        return dict(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            worker_class_fn=worker_class_fn,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    # Child classes can override this to return the actual env vars.
    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info(f"use_ray_spmd_worker: {self.use_ray_spmd_worker}")

        # Create the workers.
        driver_ip = get_ip()
        logger.info(f"driver_ip: {driver_ip}")
        worker_wrapper_kwargs = self._get_worker_wrapper_args()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

            if self.use_ray_spmd_worker:
                self.workers.append(worker)
            else:
                worker_ip = ray.get(worker.get_node_ip.remote())
                if worker_ip == driver_ip and self.driver_dummy_worker is None:
                    # If the worker is on the same node as the driver, we use
                    # it as the resource holder for the driver process.
                    self.driver_dummy_worker = worker
                    self.driver_worker = RayWorkerWrapper(
                        **worker_wrapper_kwargs)
                else:
                    # Else, add it to the list of workers.
                    self.workers.append(worker)

        logger.debug(f"workers: {self.workers}")
        logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (the Aphrodite
               engine), it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first.
            """
            ip = ray.get(worker.get_node_ip.remote())
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
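
        # Worked example (illustrative; the IP addresses are made up): with a
        # driver node 10.0.0.1 hosting two workers and a second node 10.0.0.2
        # hosting one, the sort keys are
        #   10.0.0.1 -> (False, 2, "10.0.0.1")
        #   10.0.0.2 -> (True, 1, "10.0.0.2")
        # so the driver-node workers sort first even though their node holds
        # more workers, because the first key dominates.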

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # Convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # so string sorting is not sufficient.
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `APHRODITE_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
            **({
                "APHRODITE_ATTENTION_BACKEND": envs.APHRODITE_ATTENTION_BACKEND
            } if envs.APHRODITE_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # In the single node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)
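
        # Worked example (illustrative): with PP=2, TP=4 on the non-SPMD path,
        # global rank 0 is the in-process driver worker, so self.workers holds
        # ranks 1..7. The loop above then yields
        #   tp_driver_workers  = [rank 4]              (rank 0 of the 2nd TP group)
        #   non_driver_workers = [ranks 1, 2, 3, 5, 6, 7]
        # matching the broadcast pattern described in the comments above.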

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        serialized_data = self.input_encoder.encode(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        output = self.output_decoder.decode(outputs[0])
        return output

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - async_run_tensor_parallel_workers_only: If True, the method will be
          run only in the remote TP workers, not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
        - args/kwargs: All workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually.
        """
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers) if not \
            async_run_tensor_parallel_workers_only \
            else len(self.non_driver_workers)
        # If using SPMD worker, all workers are the same, so we should execute
        # the args on all workers. Otherwise, we skip the first worker's args
        # because those args will go to the driver worker.
        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, first_worker_args_index, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, first_worker_args_index, None)

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures.
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            driver_args = args if all_args is None else all_args[0]
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

            # Start the driver worker after all the ray workers.
            if not use_dummy_driver:
                driver_worker_output = [
                    self.driver_worker.execute_method(method, *driver_args,
                                                      **driver_kwargs)
                ]
            else:
                assert self.driver_dummy_worker is not None
                driver_worker_output = [
                    ray.get(
                        self.driver_dummy_worker.execute_method.remote(
                            method, *driver_args, **driver_kwargs))
                ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs
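
    # Usage sketch (taken from the call sites in this file): _run_workers
    # supports both a shared-argument form and a per-worker form, e.g.
    #
    #   self._run_workers("init_device")                      # same args for all
    #   self._run_workers("init_worker",
    #                     all_kwargs=init_worker_all_kwargs)  # one dict per rank
    #
    # In the non-SPMD case the first entry of all_args/all_kwargs goes to the
    # driver worker and the remainder is sliced onto the Ray actors.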

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def _check_ray_adag_installation(self):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.35")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        import importlib.util
        adag_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if adag_spec is None:
            raise ValueError("Ray accelerated DAG is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        assert self.parallel_config.use_ray
        self._check_ray_adag_installation()
        from ray.dag import InputNode, MultiOutputNode
        from ray.experimental.channel.torch_tensor_type import TorchTensorType

        logger.info(f"APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = "
                    f"{APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL}")
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput   # noqa: E501
            #                         -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput   # noqa: E501
            #                         -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput   # noqa: E501
            #                         -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_spmd.
                    bind(  # type: ignore[attr-defined]
                        outputs[i]) for i, worker in enumerate(tp_group)
                ]

                last_pp_rank = len(self.pp_tp_workers) - 1
                if pp_rank < last_pp_rank:
                    # Specify how intermediate tensors should be passed
                    # between pp stages; no need to specify for the last
                    # pp stage.
                    transport = "nccl" \
                        if APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
                        else "auto"
                    outputs = [
                        output.with_type_hint(
                            TorchTensorType(transport=transport))
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
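
    # Usage sketch (mirrors execute_model above and execute_model_async in the
    # async subclass below): the compiled DAG is driven with msgpack-encoded
    # requests, roughly
    #
    #   dag = self._compiled_ray_dag(enable_asyncio=False)
    #   outputs = ray.get(dag.execute(self.input_encoder.encode(req)))
    #   sampler_output = self.output_decoder.decode(outputs[0])
    #
    # where ``req`` is an ExecuteModelRequest; only the first output (TP rank 0
    # of the last PP stage) is decoded.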

    def __del__(self):
        self.shutdown()


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        outputs = await dag_future
        return self.output_decoder.decode(outputs[0])

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time.
            # We create the locks here to avoid creating them in the
            # constructor, which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for APHRODITE_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def __del__(self):
        self.shutdown()
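

# Usage sketch (illustrative; the exact constructor arguments are an
# assumption, since the executor is normally built by the engine layer from
# its config objects):
#
#   executor = RayGPUExecutorAsync(
#       model_config=model_config, cache_config=cache_config,
#       parallel_config=parallel_config, ...)   # remaining configs elided
#   sampler_outputs = await executor.execute_model_async(execute_model_req)
#
# The SPMD/compiled-DAG variants are selected entirely through the
# APHRODITE_USE_RAY_* environment variables defined at the top of this file.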