import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import msgspec
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (_run_task_with_lock,
                                    get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from aphrodite.executor.msgspec_utils import encode_hook
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# If the env var is set, it uses Ray's compiled DAG API,
# which optimizes the control plane overhead.
# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
APHRODITE_USE_RAY_COMPILED_DAG = envs.APHRODITE_USE_RAY_COMPILED_DAG
APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION
APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = (
    envs.APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
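# A minimal sketch of enabling the compiled-DAG path (both flags must be set
# together; see the assertions in RayGPUExecutor._init_executor below):
#   export APHRODITE_USE_RAY_SPMD_WORKER=1
#   export APHRODITE_USE_RAY_COMPILED_DAG=1
#   # Optionally pass intermediate tensors over NCCL channels:
#   export APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL=1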


class RayGPUExecutor(DistributedGPUExecutor):

    uses_ray: bool = True

    def _init_executor(self) -> None:
        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
        # If the env var is set, it uses Ray's compiled DAG API,
        # which optimizes the control plane overhead.
        # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = APHRODITE_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" and the other workers. Also, the rank 0 worker
        # will be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "APHRODITE_USE_RAY_COMPILED_DAG=1 requires "
                "APHRODITE_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "APHRODITE_USE_RAY_SPMD_WORKER=1 requires "
                "APHRODITE_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])

    def shutdown(self) -> None:
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        (worker_module_name, worker_class_name,
         worker_class_fn) = self._get_worker_module_and_class()

        return dict(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            worker_class_fn=worker_class_fn,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1
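        # Illustrative example (assumed config, not from this file): with
        # TP=1, PP=1 and gpu_memory_utilization=0.9, the single Ray worker
        # is scheduled with num_gpus=0.9, which leaves fractional GPU
        # capacity that Ray could schedule other actors against.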

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info(f"use_ray_spmd_worker: {self.use_ray_spmd_worker}")

        # Create the workers.
        driver_ip = get_ip()
        logger.info(f"driver_ip: {driver_ip}")
        worker_wrapper_kwargs = self._get_worker_wrapper_args()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

            if self.use_ray_spmd_worker:
                self.workers.append(worker)
            else:
                worker_ip = ray.get(worker.get_node_ip.remote())
                if worker_ip == driver_ip and self.driver_dummy_worker is None:
                    # If the worker is on the same node as the driver, we use
                    # it as the resource holder for the driver process.
                    self.driver_dummy_worker = worker
                    self.driver_worker = RayWorkerWrapper(
                        **worker_wrapper_kwargs)
                else:
                    # Else, add it to the list of workers.
                    self.workers.append(worker)

        logger.debug(f"workers: {self.workers}")
        logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (the Aphrodite
               engine), it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first.
            """
            ip = ray.get(worker.get_node_ip.remote())
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
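        # Illustrative ordering (hypothetical IPs): with driver_ip
        # "10.0.0.1" and workers spread over "10.0.0.1" and "10.0.0.2",
        # the driver-node workers sort first and each node's workers end
        # up contiguous in self.workers.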

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `APHRODITE_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
            **({
                "APHRODITE_ATTENTION_BACKEND": envs.APHRODITE_ATTENTION_BACKEND
            } if envs.APHRODITE_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
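        # For example, a worker on a node whose GPUs are [0, 1] receives
        # (values illustrative): {"CUDA_VISIBLE_DEVICES": "0,1",
        # "APHRODITE_INSTANCE_ID": <instance id>,
        # "APHRODITE_TRACE_FUNCTION": "0"}.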
        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # In the single-node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcast to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order so the correct rank returns the final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)
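        # Illustrative split for the non-SPMD path with PP=2, TP=4: the
        # driver holds global rank 0, self.workers holds ranks 1..7,
        # tp_driver_workers is [rank 4], and non_driver_workers is
        # [ranks 1, 2, 3, 5, 6, 7].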

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        serialized_data = self.input_encoder.encode(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        output = self.output_decoder.decode(outputs[0])
        return output

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - async_run_tensor_parallel_workers_only: If True, the method will
          only be run in the remote TP workers, not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
        - args/kwargs: All workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually.
        """
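        # Example usages, as seen elsewhere in this module:
        #   self._run_workers("init_device")              # shared args/kwargs
        #   self._run_workers("init_worker",
        #                     all_kwargs=init_worker_all_kwargs)  # per-worker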
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers) if not \
            async_run_tensor_parallel_workers_only \
            else len(self.non_driver_workers)
        # If using SPMD worker, all workers are the same, so we should execute
        # the args on all workers. Otherwise, we skip the first worker's args
        # because those args will go to the driver worker.
        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, first_worker_args_index, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, first_worker_args_index, None)

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            driver_args = args if all_args is None else all_args[0]
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

            # Start the driver worker after all the ray workers.
            if not use_dummy_driver:
                driver_worker_output = [
                    self.driver_worker.execute_method(method, *driver_args,
                                                      **driver_kwargs)
                ]
            else:
                assert self.driver_dummy_worker is not None
                driver_worker_output = [
                    ray.get(
                        self.driver_dummy_worker.execute_method.remote(
                            method, *driver_args, **driver_kwargs))
                ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def _check_ray_adag_installation(self):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.35")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        import importlib.util
        adag_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if adag_spec is None:
            raise ValueError("Ray accelerated DAG is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        assert self.parallel_config.use_ray
        self._check_ray_adag_installation()
        from ray.dag import InputNode, MultiOutputNode
        from ray.experimental.channel.torch_tensor_type import TorchTensorType

        logger.info(f"APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = "
                    f"{APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL}")
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput  # noqa: E501
            #                         -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput  # noqa: E501
            #                         -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput  # noqa: E501
            #                         -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput  # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_spmd.
                    bind(  # type: ignore[attr-defined]
                        outputs[i]) for i, worker in enumerate(tp_group)
                ]

                last_pp_rank = len(self.pp_tp_workers) - 1
                if pp_rank < last_pp_rank:
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage.
                    transport = "nccl" \
                        if APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
                        else "auto"
                    outputs = [
                        output.with_type_hint(
                            TorchTensorType(transport=transport))
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

    def __del__(self):
        self.shutdown()


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        self.use_ray_spmd_worker = APHRODITE_USE_RAY_SPMD_WORKER
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        outputs = await dag_future
        return self.output_decoder.decode(outputs[0])

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for APHRODITE_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so that multiple virtual
            # engines can't execute on the same stage at the same time.
            # We create the locks here to avoid creating them in the
            # constructor, which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for APHRODITE_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def __del__(self):
        self.shutdown()