import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional,
                    Tuple, Union)

from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.executor.tpu_executor import TPUExecutor
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION


class RayTPUExecutor(TPUExecutor):

    uses_ray: bool = True

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            if self.scheduler_config.is_multi_step:
                worker_module_name = "aphrodite.worker.multi_step_tpu_worker"
                worker_class_name = "MultiStepTPUWorker"
            else:
                worker_module_name = "aphrodite.worker.tpu_worker"
                worker_class_name = "TPUWorker"

            # GKE does not fetch environment information from the metadata
            # server and instead sets these from within the Ray process.
            # Therefore we need to override the Ray environment variables
            # manually.
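            # Illustrative only: on a single-host TPU v4-8 these typically
            # look like TPU_CHIPS_PER_HOST_BOUNDS="2,2,1" and
            # TPU_HOST_BOUNDS="1,1,1"; the exact values depend on the TPU
            # topology GKE provisions.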
            override_env = {}
            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
                override_env.update({
                    "TPU_CHIPS_PER_HOST_BOUNDS":
                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
                })
            if "TPU_HOST_BOUNDS" in os.environ:
                override_env.update(
                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )
            if override_env:
                worker.override_env_vars.remote(override_env)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        logger.debug(f"workers: {self.workers}")
        logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (the Aphrodite
               engine), it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first.
            """
            ip = ray.get(worker.get_node_ip.remote())
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "APHRODITE_INSTANCE_ID": APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION": str(APHRODITE_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # In the single-node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
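        # Each worker's local_rank is its index among the workers on its own
        # node (via node_workers), while rank is its global index across all
        # nodes; both come from the node/TPU mapping gathered above.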
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True, the method will be run only
          in the remote workers, not the driver worker. It will also be run
          asynchronously and return a list of futures rather than blocking on
          the results.
        - args/kwargs: All workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        # The first element of all_args/all_kwargs is reserved for the driver
        # worker, so the remote workers consume the remaining elements.
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
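        # ray.get blocks until every remote execute_method call has finished;
        # results are returned in worker order, after the driver's output.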
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_blocks = self._run_workers("determine_num_available_blocks")
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)


class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
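
# Rough lifecycle sketch (the engine normally drives these calls; shown here
# only for orientation, with construction details elided):
#
#   executor = RayTPUExecutor(...)   # the base executor calls _init_executor,
#                                    # spawning one Ray actor per TPU bundle
#                                    # in the placement group
#   tpu_blocks, cpu_blocks = executor.determine_num_available_blocks()
#   executor.initialize_cache(tpu_blocks, cpu_blocks)
#   outputs = executor.execute_model(execute_model_req)
#   # execute_model_req is an ExecuteModelRequest built by the scheduler; only
#   # the driver's SamplerOutput list is returned.
#   executor.stop_remote_worker_execution_loop()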