import asyncio
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

from loguru import logger

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.engine.ray_tools import RayWorkerWrapper, ray
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set to 1, Ray's compiled DAG API is used, which reduces
# the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", "0") == "1"


class RayGPUExecutor(ExecutorBase):

    def _init_executor(self) -> None:
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."
        assert self.parallel_config.worker_use_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, set the profiling configuration for
        # the Ray workers as a runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })
        return ray_remote_kwargs

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if self.parallel_config.tensor_parallel_size == 1:
            # For the single-GPU case, use a Ray worker with constrained
            # memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the Ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual Ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="aphrodite.task_handler.worker",
                worker_class_name="Worker",
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
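                # The dummy actor only pins this bundle's GPU inside the
                # placement group; the driver's real Worker is constructed
                # below as ``self.driver_worker`` and runs in this process
                # rather than as a Ray actor.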
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="aphrodite.task_handler.worker",
                    worker_class_name="Worker",
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the node and GPU IDs used by every worker (driver included).
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
        all_args_to_update_environment_variables = []
        for (node_id, _) in worker_node_and_gpu_ids:
            all_args_to_update_environment_variables.append([{
                "CUDA_VISIBLE_DEVICES":
                ",".join(map(str, node_gpus[node_id])),
            }])
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        def collect_arg_helper_func(**kwargs):
            # Collect the keyword arguments into a dict.
            return kwargs

        # Initialize the workers (index 0 is the driver, then the Ray actors).
        init_worker_all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            init_worker_all_kwargs.append(
                collect_arg_helper_func(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    device_config=self.device_config,
                    cache_config=self.cache_config,
                    local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    lora_config=self.lora_config,
                    vision_language_config=self.vision_language_config,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - Tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and
        # CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""

        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
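        # The "minimum concurrency" reported below is
        #   num_gpu_blocks * block_size / max_model_len,
        # i.e. how many maximum-length sequences fit in the GPU KV cache at
        # once. With hypothetical values (4096 blocks, block size 16,
        # max_model_len 8192) this is 4096 * 16 / 8192 = 8.00x.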
logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") logger.info( f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501 ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks self._run_workers("initialize_cache", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], num_lookahead_slots: int = 0) -> List[SamplerOutput]: all_outputs = self._run_workers( "execute_model", driver_kwargs={ "seq_group_metadata_list": seq_group_metadata_list, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, "blocks_to_copy": blocks_to_copy, }, use_ray_compiled_dag=USE_RAY_COMPILED_DAG) # Only the driver worker returns the sampling results. output = all_outputs[0] return output def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return self._run_workers( "add_lora", lora_request=lora_request, ) def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." return self._run_workers( "remove_lora", lora_id=lora_id, ) def list_loras(self) -> Set[int]: return self._run_workers("list_loras") def _run_workers( self, method: str, *args, driver_args: Optional[Tuple[Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, all_args: Optional[List[List[Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers.""" if driver_args is None: driver_args = args if driver_kwargs is None: driver_kwargs = kwargs # For MyPy type checking assert driver_args is not None assert driver_kwargs is not None if all_args is None: all_args = [driver_args] + [args] * len(self.workers) if all_kwargs is None: all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers) assert all_args is not None assert all_kwargs is not None if max_concurrent_workers: raise NotImplementedError( "max_concurrent_workers is not supported yet.") if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO: Fix it. output_channels = self.forward_dag.execute(1) else: # Start the ray workers first. ray_worker_outputs = [ worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs ) in zip(self.workers, all_args[1:], all_kwargs[1:]) ] if driver_args is None: driver_args = args if driver_kwargs is None: driver_kwargs = kwargs # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( method, *all_args[0], **all_kwargs[0]) else: driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( method, *all_args[0], **all_kwargs[0])) # Get the results of the ray workers. if self.workers: if use_ray_compiled_dag: try: ray_worker_outputs = [ pickle.loads(chan.begin_read()) for chan in output_channels ] finally: # Has to call end_read in order to reuse the DAG. 

    def _compiled_ray_dag(self):
        import pkg_resources
        from packaging import version

        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions instead of raw strings so that, e.g.,
        # "2.10" is not considered older than "2.9".
        if version.parse(current_version) < version.parse(required_version):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, the compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if the engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(
                actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}. ")


class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        coros = []

        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Run the driver worker asynchronously.
        def helper():
            return self.driver_worker.execute_method(method, *driver_args,
                                                     **driver_kwargs)

        driver_executor = make_async(helper)
        coros.append(driver_executor())

        # Run the Ray workers asynchronously.
        for worker in self.workers:
            coros.append(worker.execute_method.remote(method, *args, **kwargs))

        all_outputs = await asyncio.gather(*coros)
        return all_outputs

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int = 0,
    ) -> List[SamplerOutput]:
        all_outputs = await self._run_workers_async(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
                "num_lookahead_slots": num_lookahead_slots,
            })

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output
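

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not used by Aphrodite): _run_workers boils
# down to "launch the remote Ray calls, run the driver in-process, gather the
# remote results". The toy EchoWorker below is hypothetical and only mirrors
# that control flow; it assumes Ray is installed and starts a local cluster.
# The __main__ guard keeps it from running on import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoWorker:
        """Stand-in for RayWorkerWrapper: executes a named method by echoing
        it back together with its keyword arguments."""

        def execute_method(self, method: str, **kwargs) -> str:
            return f"{method}({kwargs})"

    ray.init(ignore_reinit_error=True)
    driver = EchoWorker()
    remote_workers = [ray.remote(EchoWorker).remote() for _ in range(2)]

    # Launch the remote calls first (non-blocking), then run the driver
    # locally, then gather -- the same ordering as _run_workers above.
    futures = [
        w.execute_method.remote("init_device") for w in remote_workers
    ]
    outputs = [driver.execute_method("init_device")] + ray.get(futures)
    print(outputs)  # [driver_output, worker_0_output, worker_1_output]
    ray.shutdown()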