123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313 |
- import asyncio
- import os
- from collections import defaultdict
- from itertools import islice, repeat
- from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
- Union)
- from loguru import logger
- from aphrodite import envs
- from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
- from aphrodite.common.utils import (get_aphrodite_instance_id,
- get_distributed_init_method, get_ip,
- get_open_port, make_async)
- from aphrodite.executor.executor_base import ExecutorAsyncBase
- from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
- from aphrodite.executor.tpu_executor import TPUExecutor
# `ray` may be None (presumably when Ray is not installed), in which case the
# scheduling-strategy import is skipped.
if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Needed only for type annotations, so it is not imported at runtime.
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# Module-level alias of the env setting; its string form is pushed to every
# worker's environment during initialization.
APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION
class RayTPUExecutor(TPUExecutor):
    """TPU executor that spreads workers across Ray actors.

    One worker lives in the driver process (``self.driver_worker``); a Ray
    actor on the same node (``self.driver_dummy_worker``) holds the TPU
    resource on its behalf. Every other TPU bundle in the placement group
    gets a remote ``RayWorkerWrapper`` actor stored in ``self.workers``.
    """

    def __init__(self, *args, **kwargs):
        """Set up bookkeeping attributes, then defer to the base constructor.

        The attributes are assigned before ``super().__init__`` runs so they
        already exist by the time the base class finishes constructing.
        """
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        """Create the Ray TPU workers from the configured placement group."""
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless the user explicitly
        # opted in with RAY_USAGE_STATS_ENABLED=1.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Spawn one Ray actor per TPU bundle and initialize every worker.

        Args:
            placement_group: Ray placement group whose bundles are scanned;
                bundles without a truthy ``"TPU"`` entry are skipped.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``.

        Raises:
            ValueError: If no TPU worker ended up on the driver's IP address.
        """
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            worker_module_name = "aphrodite.task_handler.tpu_worker"
            worker_class_name = "TPUWorker"

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        # Get the set of TPU IDs used on each node. The dummy driver actor
        # answers for the driver slot, so index 0 of the result is the driver.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        # Map each node to the ranks of the workers placed on it.
        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper. A worker's
        # local rank is its position among the ranks sharing its node.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True the method will be run only
          in the remote workers, not the driver worker. It will also be
          run asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually; element 0 is used for the driver and the remaining
          elements are paired with ``self.workers`` in order.
        - use_dummy_driver: If True the driver's share of the call is executed
          on the dummy driver actor instead of the in-process driver worker.

        Returns the driver's result followed by the remote workers' results,
        unless async_run_remote_workers_only is set (see above).

        Raises NotImplementedError when a truthy max_concurrent_workers is
        given. use_ray_compiled_dag is accepted but not used here.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        # Skip element 0 of the per-worker lists: it belongs to the driver.
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return the (TPU, CPU) block counts usable across all workers.

        Each worker reports a pair of counts; the minimum of each is taken so
        the result never exceeds what any single worker reported.
        """
        num_blocks = self._run_workers("determine_num_available_blocks", )
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the cache sizes and initialize the cache on every worker.

        Args:
            num_gpu_blocks: Number of device (TPU) blocks; the parameter keeps
                its "gpu" name to match ``cache_config.num_gpu_blocks``.
            num_cpu_blocks: Number of CPU blocks.
        """
        # `logger` is loguru's, which formats with str.format-style braces;
        # %-style placeholders would be logged verbatim without the values.
        logger.info("# TPU blocks: {}, # CPU blocks: {}", num_gpu_blocks,
                    num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Execute one model step and return the driver's sampler outputs.

        On the first call (or after the loop has been stopped) the remote
        workers' execution loop is started asynchronously before the driver
        runs the request.
        """
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the remote workers' execution loop, if one is running."""
        if self.parallel_worker_tasks is None:
            return

        # Running the driver with no request tells the remote loops to stop.
        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)
class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
    """Asyncio-flavoured variant of the Ray TPU executor.

    The driver worker's ``execute_method`` is wrapped with ``make_async`` so
    it can be awaited, and the remote workers' execution loop is tracked as
    an asyncio task in ``self.parallel_worker_tasks``.
    """

    def __init__(self, *args, **kwargs):
        """Build the executor, then wrap the driver's execute_method."""
        super().__init__(*args, **kwargs)
        driver_method = self.driver_worker.execute_method
        self.driver_exec_method = make_async(driver_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step, starting the remote loop first if needed."""
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            worker_loop = self._start_worker_execution_loop()
            self.parallel_worker_tasks = asyncio.create_task(worker_loop)

        # Only the driver worker returns the sampling results.
        driver_outputs = await self._driver_execute_model_async(
            execute_model_req)
        return driver_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote execution loop and wait for it to finish."""
        if self.parallel_worker_tasks is None:
            return

        # Running the driver with no request makes the remote loops stop.
        await self._driver_execute_model_async()
        running_task, self.parallel_worker_tasks = (
            self.parallel_worker_tasks, None)
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await running_task

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Await ``execute_model`` on the driver worker."""
        result = await self.driver_exec_method("execute_model",
                                               execute_model_req)
        return result

    async def _start_worker_execution_loop(self):
        """Start the execution loop on every remote worker and await them."""
        pending_loops = (
            remote_worker.execute_method.remote("start_worker_execution_loop")
            for remote_worker in self.workers)
        return await asyncio.gather(*pending_loops)
|