Преглед изворни кода

feat: model executor refactor (#367)

* pipe the executor classes to the engine

* fix ray error

* formatting
AlpinDale пре 11 месеци
родитељ
комит
0f6d56b07f

+ 2 - 2
aphrodite/__init__.py

@@ -1,7 +1,7 @@
 from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_cluster
+from aphrodite.engine.ray_tools import initialize_ray_cluster
 from aphrodite.endpoints.llm import LLM
 from aphrodite.common.outputs import CompletionOutput, RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
@@ -17,5 +17,5 @@ __all__ = [
     "EngineArgs",
     "AsyncAphrodite",
     "AsyncEngineArgs",
-    "initialize_cluster",
+    "initialize_ray_cluster",
 ]

+ 6 - 1
aphrodite/common/config.py

@@ -1,4 +1,4 @@
-from typing import Optional, Union, ClassVar
+from typing import TYPE_CHECKING, Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
 from packaging.version import Version
@@ -11,6 +11,9 @@ from aphrodite.transformers_utils.config import get_config
 from aphrodite.common.utils import (get_cpu_memory, is_hip, is_neuron,
                                     get_nvcc_cuda_version)
 
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
 _GB = 1 << 30
 
 
@@ -482,6 +485,7 @@ class ParallelConfig:
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         ray_workers_use_nsight: bool = False,
+        placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         if is_neuron():
@@ -497,6 +501,7 @@ class ParallelConfig:
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
         self.ray_workers_use_nsight = ray_workers_use_nsight
+        self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
         # Ray worker is not supported for Neuron backend.

+ 46 - 440
aphrodite/engine/aphrodite_engine.py

@@ -1,19 +1,5 @@
-import copy
-from collections import defaultdict
-import os
 import time
-import pickle
-import importlib
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
 from loguru import logger
 from transformers import PreTrainedTokenizer
 
@@ -29,12 +15,9 @@ from aphrodite.common.config import (
 )
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
+from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.engine.metrics import StatLogger, Stats
-from aphrodite.engine.ray_tools import (
-    RayWorkerAphrodite,
-    initialize_cluster,
-    ray,
-)
+from aphrodite.engine.ray_tools import (initialize_ray_cluster)
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (
@@ -51,33 +34,11 @@ from aphrodite.transformers_utils.tokenizer import (
     TokenizerGroup,
 )
 from aphrodite.common.utils import (
-    Counter,
-    set_cuda_visible_devices,
-    get_ip,
-    get_open_port,
-    get_distributed_init_method,
-)
+    Counter, )
 from aphrodite.common.logger import setup_logger
 
-if ray:
-    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-
-if TYPE_CHECKING:
-    from ray.util.placement_group import PlacementGroup
-
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "aphrodite.task_handler.worker",
-    "neuron": "aphrodite.task_handler.neuron_worker",
-}
-
-# If the env var is set, it uses the Ray's compiled DAG API
-# which optimizes the control plane overhead.
-# Run APHRODITE with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
-USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
-
 
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
@@ -103,8 +64,8 @@ class AphroditeEngine:
         scheduler_config: The configuration related to the request scheduler.
         device_config: The configuration related to the device.
         lora_config: The configuration related to LoRA.
-        placement_group: Ray placement group for distributed execution.
-            Required for distributed execution.
+        executor_class: The model executor class for managing distributed
+            execution.
         log_stats: Whether to log statistics.
     """
 
@@ -116,7 +77,7 @@ class AphroditeEngine:
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
-        placement_group: Optional["PlacementGroup"],
+        executor_class: Type[ExecutorBase],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -148,33 +109,13 @@ class AphroditeEngine:
         self._init_tokenizer()
         self.seq_counter = Counter()
 
-        # Create the parallel GPU workers.
-        if self.parallel_config.worker_use_ray:
-            # Disable Ray usage stats collection.
-            ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
-            if ray_usage != "1":
-                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
-            # Pass additional arguments to initialize the worker
-            additional_ray_args = {}
-            if self.parallel_config.ray_workers_use_nsight:
-                logger.info("Configuring Ray workers to use nsight.")
-                additional_ray_args = {
-                    "runtime_env": {
-                        "nsight": {
-                            "t": "cuda,cudnn,cublas",
-                            "o": "'worker_process_%p'",
-                            "cuda-graph-trace": "node",
-                        }
-                    }
-                }
-            self._init_workers_ray(placement_group, **additional_ray_args)
-        else:
-            self._init_workers()
-
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
+        self.model_executor = executor_class(model_config, cache_config,
+                                             parallel_config, scheduler_config,
+                                             device_config, lora_config)
 
         # Create the scheduler.
+        # NOTE: the cache_config here have been updated with the numbers of
+        # GPU and CPU blocks, which are profiled in the distributed executor.
         self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
 
         # Metric Logging.
@@ -185,9 +126,29 @@ class AphroditeEngine:
             )
             self.stat_logger.info("cache_config", self.cache_config)
 
-        self.forward_dag = None
-        if USE_RAY_COMPILED_DAG:
-            self.forward_dag = self._compiled_ray_dag()
+    @classmethod
+    def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+
+        # Initialize the cluster and specify the executor class.
+        if parallel_config.worker_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
+            executor_class = RayGPUExecutor
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from aphrodite.executor.gpu_executor import GPUExecutor
+            executor_class = GPUExecutor
+
+        # Create the LLM engine.
+        engine = cls(*engine_configs,
+                     executor_class=executor_class,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine
 
     def __reduce__(self):
         # This is to ensure that the AphroditeEngine is not referenced in
@@ -201,40 +162,6 @@ class AphroditeEngine:
                               sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
-
-    def _init_workers(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
-
-        assert (self.parallel_config.world_size == 1
-                ), "Ray is required if parallel_config.world_size > 1."
-
-        self.workers: List[Worker] = []
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            # kv_quant_params_path=(self.cache_config.cache_quant_params_path),
-            is_driver_worker=True,
-        )
-        self._run_workers("init_model")
-        self._run_workers("load_model")
-
     def _init_tokenizer(self, **tokenizer_init_kwargs):
         init_kwargs = dict(
             enable_lora=bool(self.lora_config),
@@ -248,131 +175,6 @@ class AphroditeEngine:
         self.tokenizer: TokenizerGroup = TokenizerGroup(
             self.model_config.tokenizer, **init_kwargs)
 
-    def _init_workers_ray(self, placement_group: "PlacementGroup",
-                          **ray_remote_kwargs):
-        if self.parallel_config.tensor_parallel_size == 1:
-            num_gpus = self.cache_config.gpu_memory_utilization
-        else:
-            num_gpus = 1
-
-        self.driver_dummy_worker: RayWorkerAphrodite = None
-        self.workers: List[RayWorkerAphrodite] = []
-
-        driver_ip = get_ip()
-        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("GPU", 0):
-                continue
-            scheduling_strategy = PlacementGroupSchedulingStrategy(
-                placement_group=placement_group,
-                placement_group_capture_child_tasks=True,
-                placement_group_bundle_index=bundle_id,
-            )
-            worker = ray.remote(
-                num_cpus=0,
-                num_gpus=num_gpus,
-                scheduling_strategy=scheduling_strategy,
-                **ray_remote_kwargs,
-            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
-
-            worker_ip = ray.get(worker.get_node_ip.remote())
-            if worker_ip == driver_ip and self.driver_dummy_worker is None:
-                # If the worker is on the same node as the driver, we use it
-                # as the resource holder for the driver process.
-                self.driver_dummy_worker = worker
-            else:
-                self.workers.append(worker)
-
-        if self.driver_dummy_worker is None:
-            raise ValueError(
-                "Ray does not allocate any GPUs on the driver node. Consider "
-                "adjusting the Ray placement group or running the driver on a "
-                "GPU node.")
-
-        driver_node_id, driver_gpu_ids = ray.get(
-            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
-        worker_node_and_gpu_ids = ray.get(
-            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
-
-        node_workers = defaultdict(list)
-        node_gpus = defaultdict(list)
-
-        node_workers[driver_node_id].append(0)
-        node_gpus[driver_node_id].extend(driver_gpu_ids)
-        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
-                                               start=1):
-            node_workers[node_id].append(i)
-            node_gpus[node_id].extend(gpu_ids)
-        for node_id, gpu_ids in node_gpus.items():
-            node_gpus[node_id] = sorted(gpu_ids)
-
-        # Set CUDA_VISIBLE_DEVICES for the driver.
-        set_cuda_visible_devices(node_gpus[driver_node_id])
-        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
-            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
-
-        distributed_init_method = get_distributed_init_method(
-            driver_ip, get_open_port())
-
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
-
-        # Initialize torch distributed process group for the workers.
-        model_config = copy.deepcopy(self.model_config)
-        parallel_config = copy.deepcopy(self.parallel_config)
-        scheduler_config = copy.deepcopy(self.scheduler_config)
-        device_config = copy.deepcopy(self.device_config)
-        lora_config = copy.deepcopy(self.lora_config)
-        kv_cache_dtype = self.cache_config.cache_dtype
-        # kv_quant_params_path = self.cache_config.cache_quant_params_path
-
-        for rank, (worker, (node_id,
-                            _)) in enumerate(zip(self.workers,
-                                                 worker_node_and_gpu_ids),
-                                             start=1):
-            local_rank = node_workers[node_id].index(rank)
-            worker.init_worker.remote(
-                lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config,
-                    parallel_config,
-                    scheduler_config,
-                    device_config,
-                    local_rank,
-                    rank,
-                    distributed_init_method,
-                    lora_config=lora_config,
-                    kv_cache_dtype=kv_cache_dtype,
-                    # kv_quant_params_path=kv_quant_params_path,
-                ))
-
-        driver_rank = 0
-        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
-        self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            driver_local_rank,
-            driver_rank,
-            distributed_init_method,
-            lora_config=self.lora_config,
-            kv_cache_dtype=kv_cache_dtype,
-            # kv_quant_params_path=kv_quant_params_path,
-            is_driver_worker=True,
-        )
-
-        # don't use cupy for eager mode
-        self._run_workers(
-            "init_model",
-            cupy_port=get_open_port()
-            if not model_config.enforce_eager else None,
-        )
-        self._run_workers(
-            "load_model",
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
-        )
-
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -381,90 +183,6 @@ class AphroditeEngine:
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
 
-    def _init_cache(self) -> None:
-        # ruff: noqa: E501
-        """Profiles the memory usage and initializes the KV cache.
-
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        More details can be found in the
-        :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
-        from class :class:`~aphrodite.task_handler.Worker`.
-
-        Afterwards, as there may be multiple workers,
-        we take the minimum number of blocks across all workers
-        to ensure this can be applied to all of them.
-
-        Finally, the engine will initialize the KV cache
-        with the calculated number of blocks.
-
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameters.
-        """
-        # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_blocks = self._run_workers(
-            "profile_num_available_blocks",
-            block_size=self.cache_config.block_size,
-            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
-        )
-
-        # Since we use a shared centralized controller, we take the minimum
-        # number of blocks across all workers to make sure all the memory
-        # operators can be applied to all workers.
-        num_gpu_blocks = min(b[0] for b in num_blocks)
-        num_cpu_blocks = min(b[1] for b in num_blocks)
-        # FIXME: Change to debug log.
-        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
-                    f"# CPU blocks: {num_cpu_blocks}")
-
-        logger.info(
-            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"
-        )
-
-        if num_gpu_blocks <= 0:
-            raise ValueError("No available memory for the cache blocks. "
-                             "Try increasing `gpu_memory_utilization` when "
-                             "initializing the engine.")
-        max_seq_len = self.cache_config.block_size * num_gpu_blocks
-        logger.info(f"Maximum sequence length allowed in the cache: "
-                    f"{max_seq_len}")
-        if self.model_config.max_model_len > max_seq_len:
-            raise ValueError(
-                f"The model's max seq len ({self.model_config.max_model_len}) "
-                "is larger than the maximum number of tokens that can be "
-                f"stored in KV cache ({max_seq_len}). Try increasing "
-                "`gpu_memory_utilization` or decreasing `max_model_len` when "
-                "initializing the engine.")
-
-        self.cache_config.num_gpu_blocks = num_gpu_blocks
-        self.cache_config.num_cpu_blocks = num_cpu_blocks
-
-        # Initialize the cache.
-        self._run_workers("init_cache_engine", cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self._run_workers("warm_up_model")
-
-    @classmethod
-    def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
-        """Creates an LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config)
-        # Create the LLM engine.
-        engine = cls(
-            *engine_configs,
-            placement_group,
-            log_stats=not engine_args.disable_log_stats,
-        )
-        return engine
-
     def encode_request(
         self,
         request_id: str,
@@ -894,7 +612,7 @@ class AphroditeEngine:
                 - A Sequence Group (SG) refer to a group of sequences
                   that are generated from the same prompt.
 
-            - Step 2: Calls the workers to execute the model.
+            - Step 2: Calls the distributed executor to execute the model.
             - Step 3: Processes the model output. This mainly includes:
 
                 - Decodes the relevant outputs.
@@ -930,20 +648,10 @@ class AphroditeEngine:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         if not scheduler_outputs.is_empty():
-            # Execute the model.
-            all_outputs = self._run_workers(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                },
-                use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
-            )
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = self.model_executor.execute_model(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []
 
@@ -999,7 +707,8 @@ class AphroditeEngine:
             # Latency Timings.
             time_last_iters = []
             for seq_group in scheduler_outputs.scheduled_seq_groups:
-                # Time since last token. (n.b. updates seq_group.metrics.last_token_time)
+                # Time since last token.
+                # (n.b. updates seq_group.metrics.last_token_time)
                 time_last_iters.append(seq_group.get_last_latency(now))
                 # Time since arrival for all finished requests.
                 if seq_group.is_finished():
@@ -1123,119 +832,16 @@ class AphroditeEngine:
             seq.output_text = seq.output_text[:-len(stop_string)]
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
-        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "add_lora",
-            lora_request=lora_request,
-        )
+        return self.model_executor.add_lora(lora_request)
 
     def remove_lora(self, lora_id: int) -> bool:
-        assert lora_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "remove_lora",
-            lora_id=lora_id,
-        )
+        return self.model_executor.remove_lora(lora_id)
 
     def list_loras(self) -> List[int]:
-        return self._run_workers("list_loras")
-
-    def _run_workers(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        max_concurrent_workers: Optional[int] = None,
-        use_ray_compiled_dag: bool = False,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-
-        if max_concurrent_workers:
-            raise NotImplementedError(
-                "max_concurrent_workers is not supported yet.")
-
-        if use_ray_compiled_dag:
-            # Right now, compiled DAG can only accept a single
-            # input.
-            # TODO: Fix it.
-            output_channels = self.forward_dag.execute(1)
-        else:
-            # Start the ray workers first.
-            ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
-                for worker in self.workers
-            ]
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
-
-        # Get the results of the ray workers.
-        if self.workers:
-            if use_ray_compiled_dag:
-                try:
-                    ray_worker_outputs = [
-                        pickle.loads(chan.begin_read())
-                        for chan in output_channels
-                    ]
-                finally:
-                    # Has to call end_read in order to reuse the DAG.
-                    for chan in output_channels:
-                        chan.end_read()
-            else:
-                ray_worker_outputs = ray.get(ray_worker_outputs)
-
-        return [driver_worker_output] + ray_worker_outputs
-
-    def _compiled_ray_dag(self):
-        from packaging import version
-        import pkg_resources
-
-        required_version = "2.9"
-        current_version = pkg_resources.get_distribution("ray").version
-
-        if version.parse(current_version) < version.parse(required_version):
-            raise ValueError(f"Ray version {required_version} or greater is "
-                             f"required, but found {current_version}")
-
-        from ray.dag import MultiOutputNode, InputNode
-
-        assert self.parallel_config.worker_use_ray
-
-        # Right now, compiled DAG requires at least 1 arg. We send
-        # a dummy value for now. It will be fixed soon.
-        with InputNode() as input_data:
-            forward_dag = MultiOutputNode([
-                worker.execute_model_compiled_dag_remote.bind(input_data)
-                for worker in self.workers
-            ])
-        return forward_dag.experimental_compile()
+        return self.model_executor.list_loras()
 
     def check_health(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
-
-    def _check_if_any_actor_is_dead(self):
-        if not self.parallel_config.worker_use_ray:
-            return
-
-        if not self.workers:
-            return
-
-        dead_actors = []
-        for actor in self.workers:
-            actor_state = ray.state.actors(actor._ray_actor_id.hex())
-            if actor_state["State"] == "DEAD":
-                dead_actors.append(actor)
-        if dead_actors:
-            raise RuntimeError("At least one Worker is dead. "
-                               f"Dead Workers: {dead_actors}. ")
+        self.model_executor.check_health()
 
 
 setup_logger()

+ 39 - 69
aphrodite/engine/async_aphrodite.py

@@ -2,8 +2,8 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator, Callable)
+from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union, AsyncIterator)
 from loguru import logger
 from transformers import PreTrainedTokenizer
 
@@ -11,7 +11,7 @@ from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import ModelConfig
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_cluster, ray
+from aphrodite.engine.ray_tools import initialize_ray_cluster, ray
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 
@@ -211,17 +211,10 @@ class _AsyncAphrodite(AphroditeEngine):
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
-            all_outputs = await self._run_workers_async(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = await self.model_executor.execute_model_async(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []
 
@@ -271,37 +264,8 @@ class _AsyncAphrodite(AphroditeEngine):
             lora_request=lora_request,
         )
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Run the driver worker asynchronously.
-        driver_executor = getattr(self.driver_worker, method)
-        coros.append(asyncio.get_event_loop().run_in_executor(
-            None, partial(driver_executor, *driver_args, **driver_kwargs)))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
-
-    async def check_health_async(self):
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
+    async def check_health_async(self) -> None:
+        self.model_executor.check_health()
 
 
 class AsyncAphrodite:
@@ -357,6 +321,34 @@ class AsyncAphrodite:
         self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
+    @classmethod
+    def from_engine_args(cls,
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = True) -> "AsyncAphrodite":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
+            executor_class = RayGPUExecutorAsync
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from aphrodite.executor.gpu_executor import GPUExecutorAsync
+            executor_class = GPUExecutorAsync
+        # Create the async LLM engine.
+        engine = cls(parallel_config.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     *engine_configs,
+                     executor_class,
+                     log_requests=not engine_args.disable_log_requests,
+                     log_stats=not engine_args.disable_log_stats,
+                     max_log_len=engine_args.max_log_len,
+                     start_engine_loop=start_engine_loop)
+        return engine
+
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
@@ -579,7 +571,7 @@ class AsyncAphrodite:
             - If the engine is not running, start the background loop,
               which iteratively invokes
               # pylint: disable=line-too-long
-              :meth:`~aphrodite.engine.async_llm_engine.AsyncAphrodite.engine_step`
+              :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
               to process the waiting requests.
             - Add the request to the engine's `RequestTracker`.
               On the next background loop, this request will be sent to
@@ -682,35 +674,13 @@ class AsyncAphrodite:
         else:
             return self.engine.get_model_config()
 
-    @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncAphrodite":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config,
-                                             engine_args.engine_use_ray)
-        # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     *engine_configs,
-                     placement_group,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
-        return engine
-
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
 
-    async def check_health(self):
+    async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         t = time.perf_counter()
         logger.debug("Starting health check...")

+ 25 - 32
aphrodite/engine/ray_tools.py

@@ -1,6 +1,6 @@
 import pickle
 
-from typing import Optional, List, Tuple, TYPE_CHECKING
+from typing import Optional, List, Tuple
 from loguru import logger
 
 from aphrodite.common.config import ParallelConfig
@@ -63,45 +63,37 @@ except ImportError as e:
     ray = None
     RayWorkerAphrodite = None
 
-if TYPE_CHECKING:
-    from ray.util.placement_group import PlacementGroup
 
-
-def initialize_cluster(
+def initialize_ray_cluster(
     parallel_config: ParallelConfig,
-    engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Optional["PlacementGroup"]:
-    """Initialize the distributed cluster probably with Ray.
+):
+    """Initialize the distributed cluster with Ray.
+    it will connect to the Ray cluster and create a placement group
+    for the workers, which includes the specification of the resources
+    for each distributed worker.
 
     Args:
         parallel_config: The configurations for parallel execution.
-        engine_use_ray: Whether to use Ray for async engine.
         ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.
-
-    Returns:
-        An optional `PlacementGroup`. It includes the specification
-        of the resources for each distributed worker. None if Ray is
-        not used.
     """
-    if parallel_config.worker_use_ray or engine_use_ray:
-        if ray is None:
-            raise ImportError(
-                "Ray is not installed. Please install Ray to use distributed "
-                "serving.")
-        # Connect to a ray cluster.
-        if is_hip():
-            ray.init(address=ray_address,
-                     ignore_reinit_error=True,
-                     num_gpus=parallel_config.world_size)
-        else:
-            ray.init(address=ray_address, ignore_reinit_error=True)
-
-    if not parallel_config.worker_use_ray:
-        assert parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
-        return None
+    if ray is None:
+        raise ImportError(
+            "Ray is not installed. Please install Ray to use distributed "
+            "serving.")
+
+    # Connect to a ray cluster.
+    if is_hip():
+        ray.init(address=ray_address,
+                 ignore_reinit_error=True,
+                 num_gpus=parallel_config.world_size)
+    else:
+        ray.init(address=ray_address, ignore_reinit_error=True)
+
+    if parallel_config.placement_group:
+        # Placement group is already set.
+        return
 
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
@@ -136,4 +128,5 @@ def initialize_cluster(
         # if they cannot be provisioned.
         ray.get(current_placement_group.ready(), timeout=1800)
 
-    return current_placement_group
+    # Set the placement group in the parallel config
+    parallel_config.placement_group = current_placement_group

+ 15 - 9
aphrodite/executor/executor_base.py

@@ -1,16 +1,20 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
 
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 
 
 class ExecutorBase(ABC):
     """Base class for all executors.
-
     An executor is responsible for executing the model on a specific device
     type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
     that can execute the model on multiple devices.
@@ -29,11 +33,13 @@ class ExecutorBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 

+ 55 - 27
aphrodite/executor/gpu_executor.py

@@ -1,16 +1,32 @@
+import importlib
 from typing import Dict, List, Optional
 
 from loguru import logger
 
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.executor.utils import check_block_size_valid
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (get_ip, get_open_port,
-                                    get_distributed_init_method, make_async)
+from aphrodite.common.utils import (
+    get_ip,
+    get_open_port,
+    get_distributed_init_method,
+    make_async,
+)
+
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "aphrodite.task_handler.worker",
+    "neuron": "aphrodite.task_handler.neuron_worker",
+}
 
 
 class GPUExecutor(ExecutorBase):
@@ -37,13 +53,20 @@ class GPUExecutor(ExecutorBase):
         # Profile the memory usage and initialize the cache.
         self._init_cache()
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_worker(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
+        assert (self.parallel_config.world_size == 1
+                ), "GPUExecutor only supports single GPU."
 
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
@@ -59,28 +82,27 @@ class GPUExecutor(ExecutorBase):
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
-        self.driver_worker.init_device()
+        self.driver_worker.init_model()
         self.driver_worker.load_model()
 
     def _init_cache(self) -> None:
         """Profiles the memory usage and initializes the KV cache.
-
         The engine first profiles the existing memory usage.
         Then, it allocates the remaining memory for KV blocks.
-
         .. tip::
             You may limit the usage of GPU memory
             by adjusting the `gpu_memory_utilization` parameter.
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_gpu_blocks, num_cpu_blocks = (
-            self.driver_worker.profile_num_available_blocks(
-                block_size=self.cache_config.block_size,
-                gpu_memory_utilization=self.cache_config.
-                gpu_memory_utilization,
-                cpu_swap_space=self.cache_config.swap_space_bytes,
-                cache_dtype=self.cache_config.cache_dtype,
-            ))
+        (
+            num_gpu_blocks,
+            num_cpu_blocks,
+        ) = self.driver_worker.profile_num_available_blocks(
+            block_size=self.cache_config.block_size,
+            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
+        )
 
         logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                     f"# CPU blocks: {num_cpu_blocks}")
@@ -89,8 +111,11 @@ class GPUExecutor(ExecutorBase):
             f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
         )
 
-        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
-                               self.model_config.max_model_len)
+        check_block_size_valid(
+            num_gpu_blocks,
+            self.cache_config.block_size,
+            self.model_config.max_model_len,
+        )
 
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -101,11 +126,13 @@ class GPUExecutor(ExecutorBase):
         # if enforce_eager is False.
         self.driver_worker.warm_up_model()
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
         output = self.driver_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
@@ -144,7 +171,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy)
+            blocks_to_copy=blocks_to_copy,
+        )
         return output
 
     async def check_health_async(self) -> None:

+ 35 - 37
aphrodite/executor/ray_gpu_executor.py

@@ -3,30 +3,22 @@ import copy
 from collections import defaultdict
 import os
 import pickle
+import importlib
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from loguru import logger
 
-from aphrodite.common.config import (
-    CacheConfig,
-    DeviceConfig,
-    ModelConfig,
-    ParallelConfig,
-    SchedulerConfig,
-    LoRAConfig,
-)
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
 from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.executor.utils import check_block_size_valid
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (
-    set_cuda_visible_devices,
-    get_ip,
-    get_open_port,
-    get_distributed_init_method,
-    make_async,
-)
+from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
+                                    get_open_port, get_distributed_init_method,
+                                    make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -34,6 +26,12 @@ if ray is not None:
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "aphrodite.task_handler.worker",
+    "neuron": "aphrodite.task_handler.neuron_worker",
+}
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -76,6 +74,13 @@ class RayGPUExecutor(ExecutorBase):
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if self.parallel_config.tensor_parallel_size == 1:
@@ -151,7 +156,7 @@ class RayGPUExecutor(ExecutorBase):
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
@@ -195,13 +200,11 @@ class RayGPUExecutor(ExecutorBase):
             is_driver_worker=True,
         )
 
-        # FIXME(woosuk): We are not properly initializing cupy NCCL when
+        # FIXME: We are not properly initializing cupy NCCL when
         # we have multiple nodes.
-        self._run_workers(
-            "init_device",
-            cupy_port=get_open_port()
-            if not model_config.enforce_eager else None,
-        )
+        self._run_workers("init_model",
+                          cupy_port=get_open_port()
+                          if not model_config.enforce_eager else None)
         self._run_workers(
             "load_model",
             max_concurrent_workers=self.parallel_config.
@@ -209,6 +212,7 @@ class RayGPUExecutor(ExecutorBase):
         )
 
     def _init_cache(self) -> None:
+        # ruff: noqa: E501
         """Profiles the memory usage and initializes the KV cache.
 
         The engine will first conduct a profiling of the existing memory usage.
@@ -228,7 +232,7 @@ class RayGPUExecutor(ExecutorBase):
         .. tip::
             You may limit the usage of GPU memory
             by adjusting the `gpu_memory_utilization` parameter.
-        """  # noqa: E501
+        """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
@@ -265,13 +269,11 @@ class RayGPUExecutor(ExecutorBase):
         # if enforce_eager is False.
         self._run_workers("warm_up_model")
 
-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
         all_outputs = self._run_workers(
             "execute_model",
             driver_kwargs={
@@ -280,8 +282,7 @@ class RayGPUExecutor(ExecutorBase):
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
             },
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
-        )
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
@@ -322,7 +323,7 @@ class RayGPUExecutor(ExecutorBase):
 
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
-            # input. TODO(sang): Fix it.
+            # input. TODO: Fix it.
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
@@ -359,7 +360,6 @@ class RayGPUExecutor(ExecutorBase):
 
     def _compiled_ray_dag(self):
         import pkg_resources
-
         required_version = "2.9"
         current_version = pkg_resources.get_distribution("ray").version
         if current_version < required_version:
@@ -367,7 +367,6 @@ class RayGPUExecutor(ExecutorBase):
                              f"required, but found {current_version}")
 
         from ray.dag import MultiOutputNode, InputNode
-
         assert self.parallel_config.worker_use_ray
 
         # Right now, compiled DAG requires at least 1 arg. We send
@@ -440,8 +439,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
-            },
-        )
+            })
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]