
feat: model executor refactor (#367)

* pipe the executor classes to the engine

* fix ray error

* formatting
AlpinDale committed 11 months ago (commit 0f6d56b07f)
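
At a high level, this commit moves worker management out of AphroditeEngine and behind executor classes: the engine now receives an executor class instead of a Ray placement group, and GPUExecutor / RayGPUExecutor own worker creation, cache profiling, and model execution. The sketch below condenses the selection logic that from_engine_args now performs; it mirrors the aphrodite_engine.py diff further down, and the helper name select_executor_class is illustrative rather than part of the codebase.

def select_executor_class(parallel_config):
    """Pick the executor class the engine will own, as from_engine_args now does."""
    if parallel_config.worker_use_ray:
        # Ray path: connect to the cluster and record the placement group
        # on parallel_config (see the ray_tools.py changes below).
        from aphrodite.engine.ray_tools import initialize_ray_cluster
        initialize_ray_cluster(parallel_config)
        from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
        return RayGPUExecutor
    # Single-GPU path: no Ray required.
    assert parallel_config.world_size == 1, (
        "Ray is required if parallel_config.world_size > 1.")
    from aphrodite.executor.gpu_executor import GPUExecutor
    return GPUExecutor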

+ 2 - 2
aphrodite/__init__.py

@@ -1,7 +1,7 @@
 from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_cluster
+from aphrodite.engine.ray_tools import initialize_ray_cluster
 from aphrodite.endpoints.llm import LLM
 from aphrodite.common.outputs import CompletionOutput, RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
@@ -17,5 +17,5 @@ __all__ = [
     "EngineArgs",
     "AsyncAphrodite",
     "AsyncEngineArgs",
-    "initialize_cluster",
+    "initialize_ray_cluster",
 ]

+ 6 - 1
aphrodite/common/config.py

@@ -1,4 +1,4 @@
-from typing import Optional, Union, ClassVar
+from typing import TYPE_CHECKING, Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
 from packaging.version import Version
@@ -11,6 +11,9 @@ from aphrodite.transformers_utils.config import get_config
 from aphrodite.common.utils import (get_cpu_memory, is_hip, is_neuron,
                                     get_nvcc_cuda_version)
 
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
 _GB = 1 << 30
 
 
@@ -482,6 +485,7 @@ class ParallelConfig:
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         ray_workers_use_nsight: bool = False,
+        placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         if is_neuron():
@@ -497,6 +501,7 @@ class ParallelConfig:
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
         self.ray_workers_use_nsight = ray_workers_use_nsight
+        self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
         # Ray worker is not supported for Neuron backend.

+ 46 - 440
aphrodite/engine/aphrodite_engine.py

@@ -1,19 +1,5 @@
-import copy
-from collections import defaultdict
-import os
 import time
-import pickle
-import importlib
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
 from loguru import logger
 from transformers import PreTrainedTokenizer
 
@@ -29,12 +15,9 @@ from aphrodite.common.config import (
 )
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
+from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.engine.metrics import StatLogger, Stats
-from aphrodite.engine.ray_tools import (
-    RayWorkerAphrodite,
-    initialize_cluster,
-    ray,
-)
+from aphrodite.engine.ray_tools import (initialize_ray_cluster)
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (
@@ -51,33 +34,11 @@ from aphrodite.transformers_utils.tokenizer import (
     TokenizerGroup,
 )
 from aphrodite.common.utils import (
-    Counter,
-    set_cuda_visible_devices,
-    get_ip,
-    get_open_port,
-    get_distributed_init_method,
-)
+    Counter, )
 from aphrodite.common.logger import setup_logger
 
-if ray:
-    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-
-if TYPE_CHECKING:
-    from ray.util.placement_group import PlacementGroup
-
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "aphrodite.task_handler.worker",
-    "neuron": "aphrodite.task_handler.neuron_worker",
-}
-
-# If the env var is set, it uses the Ray's compiled DAG API
-# which optimizes the control plane overhead.
-# Run APHRODITE with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
-USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
-
 
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
@@ -103,8 +64,8 @@ class AphroditeEngine:
         scheduler_config: The configuration related to the request scheduler.
         device_config: The configuration related to the device.
         lora_config: The configuration related to LoRA.
-        placement_group: Ray placement group for distributed execution.
-            Required for distributed execution.
+        executor_class: The model executor class for managing distributed
+            execution.
         log_stats: Whether to log statistics.
     """
 
@@ -116,7 +77,7 @@
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
-        placement_group: Optional["PlacementGroup"],
+        executor_class: Type[ExecutorBase],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -148,33 +109,13 @@ class AphroditeEngine:
         self._init_tokenizer()
         self.seq_counter = Counter()
 
-        # Create the parallel GPU workers.
-        if self.parallel_config.worker_use_ray:
-            # Disable Ray usage stats collection.
-            ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
-            if ray_usage != "1":
-                os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
-            # Pass additional arguments to initialize the worker
-            additional_ray_args = {}
-            if self.parallel_config.ray_workers_use_nsight:
-                logger.info("Configuring Ray workers to use nsight.")
-                additional_ray_args = {
-                    "runtime_env": {
-                        "nsight": {
-                            "t": "cuda,cudnn,cublas",
-                            "o": "'worker_process_%p'",
-                            "cuda-graph-trace": "node",
-                        }
-                    }
-                }
-            self._init_workers_ray(placement_group, **additional_ray_args)
-        else:
-            self._init_workers()
-
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
+        self.model_executor = executor_class(model_config, cache_config,
+                                             parallel_config, scheduler_config,
+                                             device_config, lora_config)
 
         # Create the scheduler.
+        # NOTE: the cache_config here have been updated with the numbers of
+        # GPU and CPU blocks, which are profiled in the distributed executor.
         self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
 
         # Metric Logging.
@@ -185,9 +126,29 @@ class AphroditeEngine:
             )
             self.stat_logger.info("cache_config", self.cache_config)
 
-        self.forward_dag = None
-        if USE_RAY_COMPILED_DAG:
-            self.forward_dag = self._compiled_ray_dag()
+    @classmethod
+    def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+
+        # Initialize the cluster and specify the executor class.
+        if parallel_config.worker_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
+            executor_class = RayGPUExecutor
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from aphrodite.executor.gpu_executor import GPUExecutor
+            executor_class = GPUExecutor
+
+        # Create the LLM engine.
+        engine = cls(*engine_configs,
+                     executor_class=executor_class,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine
 
     def __reduce__(self):
         # This is to ensure that the AphroditeEngine is not referenced in
@@ -201,40 +162,6 @@
                               sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
-
-    def _init_workers(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
-
-        assert (self.parallel_config.world_size == 1
-                ), "Ray is required if parallel_config.world_size > 1."
-
-        self.workers: List[Worker] = []
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            # kv_quant_params_path=(self.cache_config.cache_quant_params_path),
-            is_driver_worker=True,
-        )
-        self._run_workers("init_model")
-        self._run_workers("load_model")
-
     def _init_tokenizer(self, **tokenizer_init_kwargs):
         init_kwargs = dict(
             enable_lora=bool(self.lora_config),
@@ -248,131 +175,6 @@
         self.tokenizer: TokenizerGroup = TokenizerGroup(
             self.model_config.tokenizer, **init_kwargs)
 
-    def _init_workers_ray(self, placement_group: "PlacementGroup",
-                          **ray_remote_kwargs):
-        if self.parallel_config.tensor_parallel_size == 1:
-            num_gpus = self.cache_config.gpu_memory_utilization
-        else:
-            num_gpus = 1
-
-        self.driver_dummy_worker: RayWorkerAphrodite = None
-        self.workers: List[RayWorkerAphrodite] = []
-
-        driver_ip = get_ip()
-        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
-            if not bundle.get("GPU", 0):
-                continue
-            scheduling_strategy = PlacementGroupSchedulingStrategy(
-                placement_group=placement_group,
-                placement_group_capture_child_tasks=True,
-                placement_group_bundle_index=bundle_id,
-            )
-            worker = ray.remote(
-                num_cpus=0,
-                num_gpus=num_gpus,
-                scheduling_strategy=scheduling_strategy,
-                **ray_remote_kwargs,
-            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
-
-            worker_ip = ray.get(worker.get_node_ip.remote())
-            if worker_ip == driver_ip and self.driver_dummy_worker is None:
-                # If the worker is on the same node as the driver, we use it
-                # as the resource holder for the driver process.
-                self.driver_dummy_worker = worker
-            else:
-                self.workers.append(worker)
-
-        if self.driver_dummy_worker is None:
-            raise ValueError(
-                "Ray does not allocate any GPUs on the driver node. Consider "
-                "adjusting the Ray placement group or running the driver on a "
-                "GPU node.")
-
-        driver_node_id, driver_gpu_ids = ray.get(
-            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
-        worker_node_and_gpu_ids = ray.get(
-            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
-
-        node_workers = defaultdict(list)
-        node_gpus = defaultdict(list)
-
-        node_workers[driver_node_id].append(0)
-        node_gpus[driver_node_id].extend(driver_gpu_ids)
-        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
-                                               start=1):
-            node_workers[node_id].append(i)
-            node_gpus[node_id].extend(gpu_ids)
-        for node_id, gpu_ids in node_gpus.items():
-            node_gpus[node_id] = sorted(gpu_ids)
-
-        # Set CUDA_VISIBLE_DEVICES for the driver.
-        set_cuda_visible_devices(node_gpus[driver_node_id])
-        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
-            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
-
-        distributed_init_method = get_distributed_init_method(
-            driver_ip, get_open_port())
-
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
-
-        # Initialize torch distributed process group for the workers.
-        model_config = copy.deepcopy(self.model_config)
-        parallel_config = copy.deepcopy(self.parallel_config)
-        scheduler_config = copy.deepcopy(self.scheduler_config)
-        device_config = copy.deepcopy(self.device_config)
-        lora_config = copy.deepcopy(self.lora_config)
-        kv_cache_dtype = self.cache_config.cache_dtype
-        # kv_quant_params_path = self.cache_config.cache_quant_params_path
-
-        for rank, (worker, (node_id,
-                            _)) in enumerate(zip(self.workers,
-                                                 worker_node_and_gpu_ids),
-                                             start=1):
-            local_rank = node_workers[node_id].index(rank)
-            worker.init_worker.remote(
-                lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config,
-                    parallel_config,
-                    scheduler_config,
-                    device_config,
-                    local_rank,
-                    rank,
-                    distributed_init_method,
-                    lora_config=lora_config,
-                    kv_cache_dtype=kv_cache_dtype,
-                    # kv_quant_params_path=kv_quant_params_path,
-                ))
-
-        driver_rank = 0
-        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
-        self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            driver_local_rank,
-            driver_rank,
-            distributed_init_method,
-            lora_config=self.lora_config,
-            kv_cache_dtype=kv_cache_dtype,
-            # kv_quant_params_path=kv_quant_params_path,
-            is_driver_worker=True,
-        )
-
-        # don't use cupy for eager mode
-        self._run_workers(
-            "init_model",
-            cupy_port=get_open_port()
-            if not model_config.enforce_eager else None,
-        )
-        self._run_workers(
-            "load_model",
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
-        )
-
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -381,90 +183,6 @@
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
 
-    def _init_cache(self) -> None:
-        # ruff: noqa: E501
-        """Profiles the memory usage and initializes the KV cache.
-
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        More details can be found in the
-        :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
-        from class :class:`~aphrodite.task_handler.Worker`.
-
-        Afterwards, as there may be multiple workers,
-        we take the minimum number of blocks across all workers
-        to ensure this can be applied to all of them.
-
-        Finally, the engine will initialize the KV cache
-        with the calculated number of blocks.
-
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameters.
-        """
-        # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_blocks = self._run_workers(
-            "profile_num_available_blocks",
-            block_size=self.cache_config.block_size,
-            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
-        )
-
-        # Since we use a shared centralized controller, we take the minimum
-        # number of blocks across all workers to make sure all the memory
-        # operators can be applied to all workers.
-        num_gpu_blocks = min(b[0] for b in num_blocks)
-        num_cpu_blocks = min(b[1] for b in num_blocks)
-        # FIXME: Change to debug log.
-        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
-                    f"# CPU blocks: {num_cpu_blocks}")
-
-        logger.info(
-            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"
-        )
-
-        if num_gpu_blocks <= 0:
-            raise ValueError("No available memory for the cache blocks. "
-                             "Try increasing `gpu_memory_utilization` when "
-                             "initializing the engine.")
-        max_seq_len = self.cache_config.block_size * num_gpu_blocks
-        logger.info(f"Maximum sequence length allowed in the cache: "
-                    f"{max_seq_len}")
-        if self.model_config.max_model_len > max_seq_len:
-            raise ValueError(
-                f"The model's max seq len ({self.model_config.max_model_len}) "
-                "is larger than the maximum number of tokens that can be "
-                f"stored in KV cache ({max_seq_len}). Try increasing "
-                "`gpu_memory_utilization` or decreasing `max_model_len` when "
-                "initializing the engine.")
-
-        self.cache_config.num_gpu_blocks = num_gpu_blocks
-        self.cache_config.num_cpu_blocks = num_cpu_blocks
-
-        # Initialize the cache.
-        self._run_workers("init_cache_engine", cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self._run_workers("warm_up_model")
-
-    @classmethod
-    def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
-        """Creates an LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config)
-        # Create the LLM engine.
-        engine = cls(
-            *engine_configs,
-            placement_group,
-            log_stats=not engine_args.disable_log_stats,
-        )
-        return engine
-
     def encode_request(
         self,
         request_id: str,
@@ -894,7 +612,7 @@
                 - A Sequence Group (SG) refer to a group of sequences
                   that are generated from the same prompt.
 
-            - Step 2: Calls the workers to execute the model.
+            - Step 2: Calls the distributed executor to execute the model.
             - Step 3: Processes the model output. This mainly includes:
 
                 - Decodes the relevant outputs.
@@ -930,20 +648,10 @@
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         if not scheduler_outputs.is_empty():
-            # Execute the model.
-            all_outputs = self._run_workers(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                },
-                use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
-            )
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = self.model_executor.execute_model(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []
 
@@ -999,7 +707,8 @@
             # Latency Timings.
             time_last_iters = []
             for seq_group in scheduler_outputs.scheduled_seq_groups:
-                # Time since last token. (n.b. updates seq_group.metrics.last_token_time)
+                # Time since last token.
+                # (n.b. updates seq_group.metrics.last_token_time)
                 time_last_iters.append(seq_group.get_last_latency(now))
                 # Time since arrival for all finished requests.
                 if seq_group.is_finished():
@@ -1123,119 +832,16 @@
             seq.output_text = seq.output_text[:-len(stop_string)]
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
-        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "add_lora",
-            lora_request=lora_request,
-        )
+        return self.model_executor.add_lora(lora_request)
 
     def remove_lora(self, lora_id: int) -> bool:
-        assert lora_id > 0, "lora_id must be greater than 0."
-        return self._run_workers(
-            "remove_lora",
-            lora_id=lora_id,
-        )
+        return self.model_executor.remove_lora(lora_id)
 
     def list_loras(self) -> List[int]:
-        return self._run_workers("list_loras")
-
-    def _run_workers(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        max_concurrent_workers: Optional[int] = None,
-        use_ray_compiled_dag: bool = False,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-
-        if max_concurrent_workers:
-            raise NotImplementedError(
-                "max_concurrent_workers is not supported yet.")
-
-        if use_ray_compiled_dag:
-            # Right now, compiled DAG can only accept a single
-            # input.
-            # TODO: Fix it.
-            output_channels = self.forward_dag.execute(1)
-        else:
-            # Start the ray workers first.
-            ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
-                for worker in self.workers
-            ]
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
-
-        # Get the results of the ray workers.
-        if self.workers:
-            if use_ray_compiled_dag:
-                try:
-                    ray_worker_outputs = [
-                        pickle.loads(chan.begin_read())
-                        for chan in output_channels
-                    ]
-                finally:
-                    # Has to call end_read in order to reuse the DAG.
-                    for chan in output_channels:
-                        chan.end_read()
-            else:
-                ray_worker_outputs = ray.get(ray_worker_outputs)
-
-        return [driver_worker_output] + ray_worker_outputs
-
-    def _compiled_ray_dag(self):
-        from packaging import version
-        import pkg_resources
-
-        required_version = "2.9"
-        current_version = pkg_resources.get_distribution("ray").version
-
-        if version.parse(current_version) < version.parse(required_version):
-            raise ValueError(f"Ray version {required_version} or greater is "
-                             f"required, but found {current_version}")
-
-        from ray.dag import MultiOutputNode, InputNode
-
-        assert self.parallel_config.worker_use_ray
-
-        # Right now, compiled DAG requires at least 1 arg. We send
-        # a dummy value for now. It will be fixed soon.
-        with InputNode() as input_data:
-            forward_dag = MultiOutputNode([
-                worker.execute_model_compiled_dag_remote.bind(input_data)
-                for worker in self.workers
-            ])
-        return forward_dag.experimental_compile()
+        return self.model_executor.list_loras()
 
     def check_health(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
-
-    def _check_if_any_actor_is_dead(self):
-        if not self.parallel_config.worker_use_ray:
-            return
-
-        if not self.workers:
-            return
-
-        dead_actors = []
-        for actor in self.workers:
-            actor_state = ray.state.actors(actor._ray_actor_id.hex())
-            if actor_state["State"] == "DEAD":
-                dead_actors.append(actor)
-        if dead_actors:
-            raise RuntimeError("At least one Worker is dead. "
-                               f"Dead Workers: {dead_actors}. ")
+        self.model_executor.check_health()
 
 
 setup_logger()
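
The caller-facing flow is unchanged by this refactor: from_engine_args builds the executor internally and step() delegates to it. A minimal usage sketch follows, assuming the parts of the engine API not shown in this diff (the model field of EngineArgs, add_request, has_unfinished_requests) behave as they did before the commit; the model name and prompt are placeholders.

from aphrodite import AphroditeEngine, EngineArgs, SamplingParams

# Placeholder model; EngineArgs fields beyond those visible in the diff are
# assumed from the pre-existing engine API.
engine = AphroditeEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("request-0", "The executor refactor",
                   SamplingParams(max_tokens=16))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)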

+ 39 - 69
aphrodite/engine/async_aphrodite.py

@@ -2,8 +2,8 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator, Callable)
+from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union, AsyncIterator)
 from loguru import logger
 from transformers import PreTrainedTokenizer
 
@@ -11,7 +11,7 @@ from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import ModelConfig
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_cluster, ray
+from aphrodite.engine.ray_tools import initialize_ray_cluster, ray
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 
@@ -211,17 +211,10 @@ class _AsyncAphrodite(AphroditeEngine):
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
-            all_outputs = await self._run_workers_async(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = await self.model_executor.execute_model_async(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []
 
@@ -271,37 +264,8 @@ class _AsyncAphrodite(AphroditeEngine):
             lora_request=lora_request,
         )
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Run the driver worker asynchronously.
-        driver_executor = getattr(self.driver_worker, method)
-        coros.append(asyncio.get_event_loop().run_in_executor(
-            None, partial(driver_executor, *driver_args, **driver_kwargs)))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
-
-    async def check_health_async(self):
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
+    async def check_health_async(self) -> None:
+        self.model_executor.check_health()
 
 
 class AsyncAphrodite:
@@ -357,6 +321,34 @@ class AsyncAphrodite:
         self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
+    @classmethod
+    def from_engine_args(cls,
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = True) -> "AsyncAphrodite":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
+            executor_class = RayGPUExecutorAsync
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from aphrodite.executor.gpu_executor import GPUExecutorAsync
+            executor_class = GPUExecutorAsync
+        # Create the async LLM engine.
+        engine = cls(parallel_config.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     *engine_configs,
+                     executor_class,
+                     log_requests=not engine_args.disable_log_requests,
+                     log_stats=not engine_args.disable_log_stats,
+                     max_log_len=engine_args.max_log_len,
+                     start_engine_loop=start_engine_loop)
+        return engine
+
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
@@ -579,7 +571,7 @@
             - If the engine is not running, start the background loop,
               which iteratively invokes
               # pylint: disable=line-too-long
-              :meth:`~aphrodite.engine.async_llm_engine.AsyncAphrodite.engine_step`
+              :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
               to process the waiting requests.
             - Add the request to the engine's `RequestTracker`.
               On the next background loop, this request will be sent to
@@ -682,35 +674,13 @@ class AsyncAphrodite:
         else:
             return self.engine.get_model_config()
 
-    @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncAphrodite":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config,
-                                             engine_args.engine_use_ray)
-        # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     *engine_configs,
-                     placement_group,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
-        return engine
-
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
 
-    async def check_health(self):
+    async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         t = time.perf_counter()
         logger.debug("Starting health check...")
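
On the async side the same selection now happens inside AsyncAphrodite.from_engine_args, which picks GPUExecutorAsync or RayGPUExecutorAsync and calls initialize_ray_cluster itself. A small construction sketch, assuming AsyncEngineArgs accepts the usual model field (the value is a placeholder):

from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite

engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model
# worker_use_ray / engine_use_ray on the args decide whether the Ray-backed
# async executor is used; no placement group is passed in by the caller.
engine = AsyncAphrodite.from_engine_args(engine_args)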

+ 25 - 32
aphrodite/engine/ray_tools.py

@@ -1,6 +1,6 @@
 import pickle
 
-from typing import Optional, List, Tuple, TYPE_CHECKING
+from typing import Optional, List, Tuple
 from loguru import logger
 
 from aphrodite.common.config import ParallelConfig
@@ -63,45 +63,37 @@ except ImportError as e:
     ray = None
     RayWorkerAphrodite = None
 
-if TYPE_CHECKING:
-    from ray.util.placement_group import PlacementGroup
 
-
-def initialize_cluster(
+def initialize_ray_cluster(
     parallel_config: ParallelConfig,
-    engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Optional["PlacementGroup"]:
-    """Initialize the distributed cluster probably with Ray.
+):
+    """Initialize the distributed cluster with Ray.
+    it will connect to the Ray cluster and create a placement group
+    for the workers, which includes the specification of the resources
+    for each distributed worker.
 
 
     Args:
     Args:
         parallel_config: The configurations for parallel execution.
         parallel_config: The configurations for parallel execution.
-        engine_use_ray: Whether to use Ray for async engine.
         ray_address: The address of the Ray cluster. If None, uses
         ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.
             the default Ray cluster address.
-
-    Returns:
-        An optional `PlacementGroup`. It includes the specification
-        of the resources for each distributed worker. None if Ray is
-        not used.
     """
     """
-    if parallel_config.worker_use_ray or engine_use_ray:
-        if ray is None:
-            raise ImportError(
-                "Ray is not installed. Please install Ray to use distributed "
-                "serving.")
-        # Connect to a ray cluster.
-        if is_hip():
-            ray.init(address=ray_address,
-                     ignore_reinit_error=True,
-                     num_gpus=parallel_config.world_size)
-        else:
-            ray.init(address=ray_address, ignore_reinit_error=True)
-
-    if not parallel_config.worker_use_ray:
-        assert parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
-        return None
+    if ray is None:
+        raise ImportError(
+            "Ray is not installed. Please install Ray to use distributed "
+            "serving.")
+
+    # Connect to a ray cluster.
+    if is_hip():
+        ray.init(address=ray_address,
+                 ignore_reinit_error=True,
+                 num_gpus=parallel_config.world_size)
+    else:
+        ray.init(address=ray_address, ignore_reinit_error=True)
+
+    if parallel_config.placement_group:
+        # Placement group is already set.
+        return
 
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
@@ -136,4 +128,5 @@ def initialize_cluster(
         # if they cannot be provisioned.
         ray.get(current_placement_group.ready(), timeout=1800)
 
-    return current_placement_group
+    # Set the placement group in the parallel config
+    parallel_config.placement_group = current_placement_group
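
The renamed helper also changes contract: initialize_ray_cluster no longer returns a placement group and no longer takes engine_use_ray; it stores the group it creates or finds on parallel_config.placement_group. A before/after sketch, with the ParallelConfig constructor arguments abbreviated and assumed to match the existing signature:

from aphrodite.common.config import ParallelConfig
from aphrodite.engine.ray_tools import initialize_ray_cluster

parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=2,
                                 worker_use_ray=True)
# Old: placement_group = initialize_cluster(parallel_config, engine_use_ray)
# New: nothing is returned; the group lands on the config instead.
initialize_ray_cluster(parallel_config)
assert parallel_config.placement_group is not None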

+ 15 - 9
aphrodite/executor/executor_base.py
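
This file defines the interface the engine now codes against. Only execute_model appears in the hunks below; the other methods the engine calls on its executor (add_lora, remove_lora, list_loras, check_health) are assumed to be declared alongside it. A toy subclass sketch, illustrative only:

from typing import Dict, List

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.executor.executor_base import ExecutorBase


class NoOpExecutor(ExecutorBase):
    """Skeleton showing the call signature the engine relies on; it is not
    instantiable until the remaining abstract methods are implemented."""

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        raise NotImplementedError("sketch only; a real executor runs a worker")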

@@ -1,16 +1,20 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
 
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 
 
 class ExecutorBase(ABC):
     """Base class for all executors.
-
     An executor is responsible for executing the model on a specific device
     type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
     that can execute the model on multiple devices.
@@ -29,11 +33,13 @@ class ExecutorBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
         """Executes one model step on the given sequences."""
         """Executes one model step on the given sequences."""
         raise NotImplementedError
         raise NotImplementedError
 
 

+ 55 - 27
aphrodite/executor/gpu_executor.py

@@ -1,16 +1,32 @@
+import importlib
 from typing import Dict, List, Optional
 
 from loguru import logger
 
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.executor.utils import check_block_size_valid
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (get_ip, get_open_port,
-                                    get_distributed_init_method, make_async)
+from aphrodite.common.utils import (
+    get_ip,
+    get_open_port,
+    get_distributed_init_method,
+    make_async,
+)
+
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "aphrodite.task_handler.worker",
+    "neuron": "aphrodite.task_handler.neuron_worker",
+}
 
 
 class GPUExecutor(ExecutorBase):
@@ -37,13 +53,20 @@ class GPUExecutor(ExecutorBase):
         # Profile the memory usage and initialize the cache.
         self._init_cache()
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_worker(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
+        assert (self.parallel_config.world_size == 1
+                ), "GPUExecutor only supports single GPU."
 
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
@@ -59,28 +82,27 @@ class GPUExecutor(ExecutorBase):
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
-        self.driver_worker.init_device()
+        self.driver_worker.init_model()
         self.driver_worker.load_model()
 
     def _init_cache(self) -> None:
         """Profiles the memory usage and initializes the KV cache.
-
         The engine first profiles the existing memory usage.
         Then, it allocates the remaining memory for KV blocks.
-
         .. tip::
             You may limit the usage of GPU memory
             by adjusting the `gpu_memory_utilization` parameter.
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_gpu_blocks, num_cpu_blocks = (
-            self.driver_worker.profile_num_available_blocks(
-                block_size=self.cache_config.block_size,
-                gpu_memory_utilization=self.cache_config.
-                gpu_memory_utilization,
-                cpu_swap_space=self.cache_config.swap_space_bytes,
-                cache_dtype=self.cache_config.cache_dtype,
-            ))
+        (
+            num_gpu_blocks,
+            num_cpu_blocks,
+        ) = self.driver_worker.profile_num_available_blocks(
+            block_size=self.cache_config.block_size,
+            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
+        )
 
         logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                     f"# CPU blocks: {num_cpu_blocks}")
@@ -89,8 +111,11 @@ class GPUExecutor(ExecutorBase):
             f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
         )
 
-        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
-                               self.model_config.max_model_len)
+        check_block_size_valid(
+            num_gpu_blocks,
+            self.cache_config.block_size,
+            self.model_config.max_model_len,
+        )
 
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
@@ -101,11 +126,13 @@ class GPUExecutor(ExecutorBase):
         # if enforce_eager is False.
         self.driver_worker.warm_up_model()
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
         output = self.driver_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
@@ -144,7 +171,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy)
+            blocks_to_copy=blocks_to_copy,
+        )
         return output
 
     async def check_health_async(self) -> None:
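
Both executors now carry the DEVICE_TO_WORKER_MODULE_MAP and _dispatch_worker pair that previously lived in the engine, so the Worker class is still imported lazily (only after CUDA_VISIBLE_DEVICES is set). A stripped-down version of that lookup, with the device type hard-coded for illustration:

import importlib

# Same mapping the executors define in this commit.
DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "aphrodite.task_handler.worker",
    "neuron": "aphrodite.task_handler.neuron_worker",
}

def dispatch_worker(device_type: str = "cuda"):
    """Resolve the Worker class for a device type, as _dispatch_worker does."""
    worker_module = DEVICE_TO_WORKER_MODULE_MAP[device_type]
    return importlib.import_module(worker_module).Worker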

+ 35 - 37
aphrodite/executor/ray_gpu_executor.py

@@ -3,30 +3,22 @@ import copy
 from collections import defaultdict
 import os
 import pickle
+import importlib
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from loguru import logger
 
-from aphrodite.common.config import (
-    CacheConfig,
-    DeviceConfig,
-    ModelConfig,
-    ParallelConfig,
-    SchedulerConfig,
-    LoRAConfig,
-)
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
 from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.executor.utils import check_block_size_valid
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (
-    set_cuda_visible_devices,
-    get_ip,
-    get_open_port,
-    get_distributed_init_method,
-    make_async,
-)
+from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
+                                    get_open_port, get_distributed_init_method,
+                                    make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -34,6 +26,12 @@ if ray is not None:
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "aphrodite.task_handler.worker",
+    "neuron": "aphrodite.task_handler.neuron_worker",
+}
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -76,6 +74,13 @@ class RayGPUExecutor(ExecutorBase):
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
                           **ray_remote_kwargs):
         if self.parallel_config.tensor_parallel_size == 1:
         if self.parallel_config.tensor_parallel_size == 1:
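The two hunks above add a device-type-to-worker-module table and a _dispatch_worker helper that resolves the Worker class at call time. A standalone sketch of the same importlib pattern follows; only the two module paths are quoted from the diff, the helper name and map here are illustrative.

# Standalone illustration of the lazy-dispatch pattern used above.
import importlib

_WORKER_MODULES = {
    "cuda": "aphrodite.task_handler.worker",
    "neuron": "aphrodite.task_handler.neuron_worker",
}


def load_worker_class(device_type: str):
    """Import the worker module for `device_type` and return its Worker class.

    Importing lazily keeps torch.cuda/xformers from being pulled in before
    CUDA_VISIBLE_DEVICES is set inside the Ray worker process.
    """
    module = importlib.import_module(_WORKER_MODULES[device_type])
    return module.Worker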
@@ -151,7 +156,7 @@ class RayGPUExecutor(ExecutorBase):
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
@@ -195,13 +200,11 @@ class RayGPUExecutor(ExecutorBase):
             is_driver_worker=True,
         )
 
-        # FIXME(woosuk): We are not properly initializing cupy NCCL when
+        # FIXME: We are not properly initializing cupy NCCL when
         # we have multiple nodes.
-        self._run_workers(
-            "init_device",
-            cupy_port=get_open_port()
-            if not model_config.enforce_eager else None,
-        )
+        self._run_workers("init_model",
+                          cupy_port=get_open_port()
+                          if not model_config.enforce_eager else None)
         self._run_workers(
             "load_model",
             max_concurrent_workers=self.parallel_config.
@@ -209,6 +212,7 @@ class RayGPUExecutor(ExecutorBase):
         )
 
     def _init_cache(self) -> None:
+        # ruff: noqa: E501
         """Profiles the memory usage and initializes the KV cache.
 
         The engine will first conduct a profiling of the existing memory usage.
@@ -228,7 +232,7 @@ class RayGPUExecutor(ExecutorBase):
         .. tip::
             You may limit the usage of GPU memory
             by adjusting the `gpu_memory_utilization` parameter.
-        """  # noqa: E501
+        """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
@@ -265,13 +269,11 @@ class RayGPUExecutor(ExecutorBase):
         # if enforce_eager is False.
         self._run_workers("warm_up_model")
 
-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
         all_outputs = self._run_workers(
             "execute_model",
             driver_kwargs={
@@ -280,8 +282,7 @@ class RayGPUExecutor(ExecutorBase):
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
                 "blocks_to_copy": blocks_to_copy,
             },
             },
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
-        )
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
 
         # Only the driver worker returns the sampling results.
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
         output = all_outputs[0]
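execute_model dispatches one method name to every worker while the driver alone receives the scheduler's block maps through driver_kwargs, and only the driver's return value is kept. A toy, framework-agnostic illustration of that fan-out follows; run_workers and the thread pool are stand-ins, not the real Ray-based _run_workers.

# Toy illustration of the driver/worker fan-out used by execute_model.
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional


def run_workers(driver: Any, workers: List[Any], method: str,
                driver_kwargs: Optional[Dict[str, Any]] = None,
                **kwargs) -> List[Any]:
    with ThreadPoolExecutor() as pool:
        # Start the (stand-in) remote workers first so they run concurrently.
        futures = [pool.submit(getattr(w, method), **kwargs) for w in workers]
        # The driver runs in-process and may get a richer set of kwargs.
        driver_output = getattr(driver, method)(**(driver_kwargs or kwargs))
        results = [driver_output] + [f.result() for f in futures]
    # Callers read results[0]: only the driver worker returns sampler output.
    return results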
@@ -322,7 +323,7 @@ class RayGPUExecutor(ExecutorBase):
 
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
-            # input. TODO(sang): Fix it.
+            # input. TODO: Fix it.
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
@@ -359,7 +360,6 @@ class RayGPUExecutor(ExecutorBase):
 
     def _compiled_ray_dag(self):
         import pkg_resources
-
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        if current_version < required_version:
@@ -367,7 +367,6 @@ class RayGPUExecutor(ExecutorBase):
                              f"required, but found {current_version}")
                              f"required, but found {current_version}")
 
 
         from ray.dag import MultiOutputNode, InputNode
         from ray.dag import MultiOutputNode, InputNode
-
         assert self.parallel_config.worker_use_ray
         assert self.parallel_config.worker_use_ray
 
 
         # Right now, compiled DAG requires at least 1 arg. We send
         # Right now, compiled DAG requires at least 1 arg. We send
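As an aside, the gate above compares the distribution version strings lexicographically, so a hypothetical "2.10" would sort below "2.9". A sketch of the same check written with packaging.version, which this changeset already uses in aphrodite/common/config.py, would be:

# Aside: the same Ray-version gate with numeric instead of string comparison.
from packaging.version import Version
import pkg_resources


def require_ray(minimum: str = "2.9") -> None:
    installed = pkg_resources.get_distribution("ray").version
    if Version(installed) < Version(minimum):
        raise ValueError(
            f"Ray version {minimum} or greater is required, "
            f"but found {installed}")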
@@ -440,8 +439,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
                 "blocks_to_copy": blocks_to_copy,
-            },
-        )
+            })
 
 
         # Only the driver worker returns the sampling results.
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
         output = all_outputs[0]
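RayGPUExecutorAsync layers asyncio variants on top of the same calls; make_async, imported from aphrodite.common.utils near the top of this file, is the wrapper used for that. A minimal sketch of a make_async-style helper, assuming it simply offloads the blocking call to the event loop's default thread pool (the real helper may differ), is:

# Minimal sketch of a make_async-style helper; an assumption about its shape,
# not the actual implementation in aphrodite.common.utils.
import asyncio
import functools
from typing import Any, Awaitable, Callable


def make_async(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
    """Wrap a blocking callable so it can be awaited without blocking the loop."""

    async def _async_wrapper(*args, **kwargs) -> Any:
        loop = asyncio.get_running_loop()
        call = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, call)

    return _async_wrapper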