
reduce code duplication by wrapping workers in a general worker class

AlpinDale 10 months ago
commit 9fff6fb507

+ 15 - 2
aphrodite/common/utils.py

@@ -252,8 +252,16 @@ def get_open_port() -> int:
             return s.getsockname()[1]
 
 
-def set_cuda_visible_devices(device_ids: List[int]) -> None:
-    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+# def set_cuda_visible_devices(device_ids: List[int]) -> None:
+#     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+
+
+def update_environment_variables(envs: Dict[str, str]):
+    for k, v in envs.items():
+        if k in os.environ:
+            logger.warning("Overwriting environment variable "
+                           f"{k}={os.environ[k]} with {v}")
+        os.environ[k] = v
 
 
 def chunk_list(lst, chunk_size):
@@ -486,3 +494,8 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
         merged_dict[key].extend(value)
 
     return dict(merged_dict)
+
+
+def init_cached_hf_modules():
+    from transformers.dynamic_module_utils import init_hf_modules
+    init_hf_modules()
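
The two helpers above replace the old set_cuda_visible_devices: update_environment_variables lets a worker receive its environment (notably CUDA_VISIBLE_DEVICES) as plain data and warns before overwriting anything already set, while init_cached_hf_modules hoists the Hugging Face dynamic-module setup that previously lived inside RayWorkerAphrodite. A minimal usage sketch; the device IDs are illustrative, not taken from the commit:

    from aphrodite.common.utils import update_environment_variables

    # Overwrites any existing value and logs a warning along the lines of
    # "Overwriting environment variable CUDA_VISIBLE_DEVICES=... with 0,1".
    update_environment_variables({"CUDA_VISIBLE_DEVICES": "0,1"})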

+ 6 - 29
aphrodite/engine/ray_tools.py

@@ -4,44 +4,24 @@ from typing import Optional, List, Tuple
 from loguru import logger
 
 from aphrodite.common.config import ParallelConfig
-from aphrodite.common.utils import is_hip, set_cuda_visible_devices, get_ip
+from aphrodite.common.utils import is_hip, get_ip
+from aphrodite.task_handler.worker_base import WorkerWrapperBase
 
 try:
     import ray
 
-    class RayWorkerAphrodite:
+    class RayWorkerWrapper(WorkerWrapperBase):
         """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to be
         lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
-        def __init__(self, init_cached_hf_modules=False) -> None:
-            if init_cached_hf_modules:
-                from transformers.dynamic_module_utils import init_hf_modules
-                init_hf_modules()
-            self.worker = None
+        def __init__(self, *args, **kwargs) -> None:
+            super().__init__(*args, **kwargs)
             # Since the compiled DAG runs the main execution
             # in a different thread that calls cuda.set_device,
             # this flag indicates whether set_device has been
             # called on that thread.
             self.compiled_dag_cuda_device_set = False
 
-        def init_worker(self, worker_init_fn):
-            self.worker = worker_init_fn()
-
-        def __getattr__(self, name):
-            return getattr(self.worker, name)
-
-        def execute_method(self, method, *args, **kwargs):
-            try:
-                executor = getattr(self, method)
-                return executor(*args, **kwargs)
-            except Exception as e:
-                # exceptions in ray worker may cause deadlock
-                # print the error and inform the user to solve the error
-                msg = (f"Error executing method {method}. "
-                       "This might cause deadlock in distributed execution.")
-                logger.exception(msg)
-                raise e
-
         def get_node_ip(self) -> str:
             return get_ip()
 
@@ -50,9 +30,6 @@ try:
             gpu_ids = ray.get_gpu_ids()
             return node_id, gpu_ids
 
-        def set_cuda_visible_devices(self, device_ids) -> None:
-            set_cuda_visible_devices(device_ids)
-
         def execute_model_compiled_dag_remote(self, ignored):
             """Used only when compiled DAG is enabled."""
             import torch
@@ -69,7 +46,7 @@ except ImportError as e:
                    "For distributed inference, please install Ray with "
                    "`pip install ray`.")
     ray = None
-    RayWorkerAphrodite = None
+    RayWorkerWrapper = None
 
 
 def initialize_ray_cluster(
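
With the wrapper in place, the executor no longer ships a pickled worker_init_fn to each actor; it only tells the actor which module and class to import later. A rough sketch of the creation-and-call pattern, mirroring ray_gpu_executor.py below (assumes Ray is installed; the resource numbers are illustrative):

    import ray
    from aphrodite.engine.ray_tools import RayWorkerWrapper

    worker = ray.remote(num_cpus=0, num_gpus=1)(RayWorkerWrapper).remote(
        worker_module_name="aphrodite.task_handler.worker",
        worker_class_name="Worker",
    )
    # Methods defined on the wrapper (or, after init_worker, on the wrapped
    # Worker) are reached through execute_method or direct actor calls.
    node_id, gpu_ids = ray.get(worker.get_node_and_gpu_ids.remote())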

+ 78 - 73
aphrodite/executor/ray_gpu_executor.py

@@ -1,19 +1,17 @@
 import asyncio
-import copy
-from collections import defaultdict
 import os
 import pickle
-from typing import TYPE_CHECKING, Any, Dict, List, Set, Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from loguru import logger
 
-from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (get_distributed_init_method, get_ip,
+                                    get_open_port, make_async)
+from aphrodite.engine.ray_tools import RayWorkerWrapper, ray
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
-                                    get_open_port, get_distributed_init_method,
-                                    make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -74,9 +72,9 @@ class RayGPUExecutor(ExecutorBase):
 
         # The driver dummy worker does not actually use any resources.
         # It holds the resource for the driver worker.
-        self.driver_dummy_worker: RayWorkerAphrodite = None
+        self.driver_dummy_worker: RayWorkerWrapper = None
         # The remaining workers are the actual ray actors.
-        self.workers: List[RayWorkerAphrodite] = []
+        self.workers: List[RayWorkerWrapper] = []
 
         if self.parallel_config.ray_workers_use_nsight:
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
@@ -97,13 +95,20 @@ class RayGPUExecutor(ExecutorBase):
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
+            )(RayWorkerWrapper).remote(
+                worker_module_name="aphrodite.task_handler.worker",
+                worker_class_name="Worker",
+            )
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
                 # If the worker is on the same node as the driver, we use it
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
+                self.driver_worker = RayWorkerWrapper(
+                    worker_module_name="aphrodite.task_handler.worker",
+                    worker_class_name="Worker",
+                )
             else:
                 # Else, added to the list of workers.
                 self.workers.append(worker)
@@ -114,80 +119,51 @@ class RayGPUExecutor(ExecutorBase):
                 "adjusting the Ray placement group or running the driver on a "
                 "GPU node.")
 
-        # Get the set of GPU IDs used on each node.
-        driver_node_id, driver_gpu_ids = ray.get(
-            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
-        worker_node_and_gpu_ids = ray.get(
-            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
+                                                    use_dummy_driver=True)
 
         node_workers = defaultdict(list)
         node_gpus = defaultdict(list)
 
-        node_workers[driver_node_id].append(0)
-        node_gpus[driver_node_id].extend(driver_gpu_ids)
-        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
-                                               start=1):
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
             node_workers[node_id].append(i)
             node_gpus[node_id].extend(gpu_ids)
         for node_id, gpu_ids in node_gpus.items():
             node_gpus[node_id] = sorted(gpu_ids)
 
         # Set CUDA_VISIBLE_DEVICES for the driver and workers.
-        set_cuda_visible_devices(node_gpus[driver_node_id])
-        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
-            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+        all_args_to_update_environment_variables = []
+        for (node_id, _) in worker_node_and_gpu_ids:
+            all_args_to_update_environment_variables.append([{
+                "CUDA_VISIBLE_DEVICES":
+                ",".join(map(str, node_gpus[node_id])),
+            }])
 
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
-
-        model_config = copy.deepcopy(self.model_config)
-        parallel_config = copy.deepcopy(self.parallel_config)
-        scheduler_config = copy.deepcopy(self.scheduler_config)
-        device_config = copy.deepcopy(self.device_config)
-        lora_config = copy.deepcopy(self.lora_config)
-        cache_config = copy.deepcopy(self.cache_config)
-        vision_language_config = copy.deepcopy(self.vision_language_config)
-
-        # Initialize the actual workers with the Worker class.
-        for rank, (worker, (node_id, _)) in enumerate(
-                zip(self.workers, worker_node_and_gpu_ids),
-                start=1,
-        ):
+        def collect_arg_helper_func(**kwargs):
+            return kwargs
+
+        init_worker_all_kwargs = []
+
+        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
             local_rank = node_workers[node_id].index(rank)
-            worker.init_worker.remote(
-                lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config=model_config,
-                    parallel_config=parallel_config,
-                    scheduler_config=scheduler_config,
-                    device_config=device_config,
-                    cache_config=cache_config,
+            init_worker_all_kwargs.append(
+                collect_arg_helper_func(
+                    model_config=self.model_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config,
+                    device_config=self.device_config,
+                    cache_config=self.cache_config,
                     local_rank=local_rank,
                     rank=rank,
                     distributed_init_method=distributed_init_method,
-                    lora_config=lora_config,
-                    vision_language_config=vision_language_config,
+                    lora_config=self.lora_config,
+                    vision_language_config=self.vision_language_config,
                 ))
 
-        # Initialize the driver worker with the Worker class.
-        driver_rank = 0
-        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
-        self.driver_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            local_rank=driver_local_rank,
-            rank=driver_rank,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
+        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
         self._run_workers(
@@ -277,13 +253,31 @@ class RayGPUExecutor(ExecutorBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
+        all_args: Optional[List[List[Any]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
+        use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # For MyPy type checking
+        assert driver_args is not None
+        assert driver_kwargs is not None
+        if all_args is None:
+            all_args = [driver_args] + [args] * len(self.workers)
+        if all_kwargs is None:
+            all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
+
+        assert all_args is not None
+        assert all_kwargs is not None
 
         if max_concurrent_workers:
             raise NotImplementedError(
@@ -296,8 +290,10 @@ class RayGPUExecutor(ExecutorBase):
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
-                for worker in self.workers
+                worker.execute_method.remote(method, *worker_args,
+                                             **worker_kwargs)
+                for (worker, worker_args, worker_kwargs
+                     ) in zip(self.workers, all_args[1:], all_kwargs[1:])
             ]
 
         if driver_args is None:
@@ -306,8 +302,13 @@ class RayGPUExecutor(ExecutorBase):
             driver_kwargs = kwargs
 
         # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
+        if not use_dummy_driver:
+            driver_worker_output = self.driver_worker.execute_method(
+                method, *all_args[0], **all_kwargs[0])
+        else:
+            driver_worker_output = ray.get(
+                self.driver_dummy_worker.execute_method.remote(
+                    method, *all_args[0], **all_kwargs[0]))
 
         # Get the results of the ray workers.
         if self.workers:
@@ -334,7 +335,7 @@ class RayGPUExecutor(ExecutorBase):
             raise ValueError(f"Ray version {required_version} or greater is "
                              f"required, but found {current_version}")
 
-        from ray.dag import MultiOutputNode, InputNode
+        from ray.dag import InputNode, MultiOutputNode
         assert self.parallel_config.worker_use_ray
 
         # Right now, compiled DAG requires at least 1 arg. We send
@@ -383,8 +384,12 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
             driver_kwargs = kwargs
 
         # Run the driver worker asynchronously.
-        driver_executor = make_async(getattr(self.driver_worker, method))
-        coros.append(driver_executor(*driver_args, **driver_kwargs))
+        def helper():
+            return self.driver_worker.execute_method(method, *driver_args,
+                                                     **driver_kwargs)
+
+        driver_executor = make_async(helper)
+        coros.append(driver_executor())
 
         # Run the ray workers asynchronously.
         for worker in self.workers:
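
The reworked _run_workers accepts either the old per-call *args/**kwargs or explicit per-rank lists: entry 0 of all_args/all_kwargs goes to the driver worker (or to the dummy driver actor when use_dummy_driver=True), and entries 1..N map one-to-one onto self.workers. A self-contained toy model of that fan-out rule; fan_out is a hypothetical helper, not part of the codebase:

    def fan_out(args, kwargs, num_workers, driver_args=None, driver_kwargs=None):
        # Mirrors the defaulting logic at the top of _run_workers.
        driver_args = args if driver_args is None else driver_args
        driver_kwargs = kwargs if driver_kwargs is None else driver_kwargs
        all_args = [driver_args] + [args] * num_workers
        all_kwargs = [driver_kwargs] + [kwargs] * num_workers
        return all_args, all_kwargs

    all_args, all_kwargs = fan_out((), {}, num_workers=3)
    assert len(all_args) == len(all_kwargs) == 4  # driver + 3 ray workers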

+ 4 - 0
aphrodite/task_handler/cpu_worker.py

@@ -140,6 +140,10 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
 
+        if self.model_config.trust_remote_code:
+            from aphrodite.common.utils import init_cached_hf_modules
+            init_cached_hf_modules()
+
         self.model_runner = CPUModelRunner(model_config,
                                            parallel_config,
                                            scheduler_config,

+ 3 - 0
aphrodite/task_handler/neuron_worker.py

@@ -33,6 +33,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
+        if self.model_config.trust_remote_code:
+            from aphrodite.common.utils import init_cached_hf_modules
+            init_cached_hf_modules()
 
         self.model_runner = NeuronModelRunner(model_config, parallel_config,
                                               scheduler_config, device_config)

+ 4 - 0
aphrodite/task_handler/worker.py

@@ -68,6 +68,10 @@ class Worker(WorkerBase):
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
 
+        if self.model_config.trust_remote_code:
+            from aphrodite.common.utils import init_cached_hf_modules
+            init_cached_hf_modules()
+
         self.vision_language_config = vision_language_config
         if self.vision_language_config:
             assert not self.lora_config, (
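
The same trust_remote_code guard is added to Worker, CPUWorker, and NeuronWorker, since the Hugging Face module cache must now be initialized inside whichever worker process ends up importing remote model code. A condensed sketch of the shared pattern (comments are mine, not from the files):

    if self.model_config.trust_remote_code:
        # Prepare the HF dynamic-module cache so code downloaded via
        # trust_remote_code can be imported inside this worker process.
        from aphrodite.common.utils import init_cached_hf_modules
        init_cached_hf_modules()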

+ 43 - 0
aphrodite/task_handler/worker_base.py

@@ -1,7 +1,12 @@
+import importlib
+import os
 from abc import ABC, abstractmethod
 from typing import Dict, List
 
+from loguru import logger
+
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import update_environment_variables
 from aphrodite.lora.request import LoRARequest
 
 
@@ -80,3 +85,41 @@ class LoraNotSupportedWorkerBase(WorkerBase):
 
     def list_loras(self) -> List[int]:
         raise ValueError(f"{type(self)} does not support LoRA")
+
+
+class WorkerWrapperBase:
+
+    def __init__(self,
+                 worker_module_name=None,
+                 worker_class_name=None) -> None:
+        self.worker_module_name = worker_module_name
+        self.worker_class_name = worker_class_name
+        self.worker = None
+
+    def update_environment_variables(self, envs: Dict[str, str]) -> None:
+        """Update environment variables for the worker."""
+        key = "CUDA_VISIBLE_DEVICES"
+        if key in envs and key in os.environ:
+            del os.environ[key]
+
+        update_environment_variables(envs)
+
+    def init_worker(self, *args, **kwargs):
+        mod = importlib.import_module(self.worker_module_name)
+        worker_class = getattr(mod, self.worker_class_name)
+        self.worker = worker_class(*args, **kwargs)
+
+    def execute_method(self, method, *args, **kwargs):
+        try:
+            if hasattr(self, method):
+                executor = getattr(self, method)
+            else:
+                executor = getattr(self.worker, method)
+            return executor(*args, **kwargs)
+        except Exception as e:
+            # Exceptions in a Ray worker may cause a deadlock; log the error
+            # and inform the user so they can resolve it.
+            msg = (f"Error executing method {method}. "
+                   "This might cause deadlock in distributed execution.")
+            logger.exception(msg)
+            raise e
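
WorkerWrapperBase is what lets the Ray actors and the in-process driver share one code path: environment variables are applied first, the concrete worker class is imported and constructed only afterwards (so CUDA_VISIBLE_DEVICES is already in effect before torch/CUDA imports), and execute_method resolves on the wrapper first, then on the wrapped worker. A minimal local, non-Ray lifecycle sketch; worker_kwargs stands in for the per-rank kwargs built by RayGPUExecutor and is not defined here:

    wrapper = WorkerWrapperBase(
        worker_module_name="aphrodite.task_handler.worker",
        worker_class_name="Worker",
    )
    wrapper.update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})  # example value
    wrapper.init_worker(**worker_kwargs)   # imports the module, builds the Worker
    wrapper.execute_method("init_device")  # tries the wrapper, then the worker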