
centralize gpu worker construction

AlpinDale 8 months ago
commit 7bcf4c3fc9
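The diffs below replace several near-identical copies of the Worker keyword list (the single-GPU driver, the speculative-decode target and draft workers, and the per-rank Ray workers) with two shared helpers on GPUExecutor: _get_worker_kwargs() and _create_worker(). A minimal, self-contained sketch of the pattern, using hypothetical names rather than the aphrodite classes:

# Illustrative only: toy stand-in for the centralization introduced below.
from typing import Any, Dict, Optional


class ExecutorSketch:

    def _get_worker_kwargs(self,
                           local_rank: int = 0,
                           rank: int = 0,
                           distributed_init_method: Optional[str] = None
                           ) -> Dict[str, Any]:
        # Single source of truth for worker construction arguments.
        return dict(local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    is_driver_worker=rank == 0)

    def per_rank_kwargs(self, world_size: int) -> list:
        # Call sites reuse the helper and override only what differs per rank.
        return [self._get_worker_kwargs(local_rank=r, rank=r)
                for r in range(world_size)]


print(ExecutorSketch().per_rank_kwargs(2))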

+ 39 - 44
aphrodite/executor/gpu_executor.py

@@ -1,10 +1,11 @@
-from typing import Dict, List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 from loguru import logger
 
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
+from aphrodite.task_handler.worker_base import WorkerWrapperBase
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
 
@@ -22,30 +23,47 @@ class GPUExecutor(ExecutorBase):
         else:
             self._init_spec_worker()
 
-    def _init_non_spec_worker(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
+    def _get_worker_kwargs(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
             model_config=self.model_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
             load_config=self.load_config,
-            local_rank=0,
-            rank=0,
+            local_rank=local_rank,
+            rank=rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            is_driver_worker=rank == 0,
+        )
+
+    def _create_worker(self,
+                       local_rank: int = 0,
+                       rank: int = 0,
+                       distributed_init_method: Optional[str] = None):
+        wrapper = WorkerWrapperBase(
+            worker_module_name="aphrodite.task_handler.worker",
+            worker_class_name="Worker",
         )
+        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
+                                                      distributed_init_method))
+        return wrapper.worker
+
+    def _init_non_spec_worker(self):
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
         self.driver_worker.init_device()
         self.driver_worker.load_model()
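The removed comment above explains why worker construction is deferred: importing the Worker module pulls in torch.cuda/xformers, which must not happen before CUDA_VISIBLE_DEVICES is set. _create_worker now delegates that deferral to WorkerWrapperBase, whose implementation is not part of this diff. Roughly, and only as an assumption about aphrodite/task_handler/worker_base.py, it resolves the worker class by name when init_worker is called:

# Hedged sketch of the wrapper's lazy construction; the real WorkerWrapperBase
# may differ in detail.
import importlib
from typing import Any


class WorkerWrapperSketch:

    def __init__(self, worker_module_name: str, worker_class_name: str) -> None:
        # Only the names are stored here; importing the worker module at this
        # point would load torch.cuda before CUDA_VISIBLE_DEVICES is set.
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker: Any = None

    def init_worker(self, *args: Any, **kwargs: Any) -> None:
        # Late import: the worker class is resolved and constructed only now.
        module = importlib.import_module(self.worker_module_name)
        worker_cls = getattr(module, self.worker_class_name)
        self.worker = worker_cls(*args, **kwargs)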
 
 
@@ -56,41 +74,18 @@ class GPUExecutor(ExecutorBase):
 
         from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
         from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
-        from aphrodite.task_handler.worker import Worker
 
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-
-        target_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
+        target_worker = self._create_worker()
 
-        draft_worker = MultiStepWorker(
+        draft_worker_kwargs = self._get_worker_kwargs()
+        # Override draft-model specific worker args.
+        draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
             # TODO allow draft-model specific load config.
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            #load_config=self.load_config,
         )
+        draft_worker = MultiStepWorker(**draft_worker_kwargs)
 
         spec_decode_worker = SpecDecodeWorker.from_workers(
             proposer_worker=draft_worker, scorer_worker=target_worker)
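The draft worker reuses the same kwargs helper and only overrides what differs for the draft model; dict.update() with keyword arguments replaces those keys and leaves the rest of the shared defaults untouched. A trivial illustration of that override step (values are placeholders, not real configs):

draft_kwargs = dict(model_config="target-model", rank=0, is_driver_worker=True)
draft_kwargs.update(model_config="draft-model")  # draft-specific override wins
print(draft_kwargs["model_config"])              # -> draft-model
print(sorted(draft_kwargs))                      # remaining defaults are kept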

+ 8 - 24
aphrodite/executor/ray_gpu_executor.py

@@ -151,29 +151,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        def collect_arg_helper_func(**kwargs):
-            # avoid writing `{"name": value}` manually
-            return kwargs
-
         # Initialize the actual workers inside worker wrapper.
-        init_worker_all_kwargs = []
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
-            local_rank = node_workers[node_id].index(rank)
-            init_worker_all_kwargs.append(
-                collect_arg_helper_func(
-                    model_config=self.model_config,
-                    parallel_config=self.parallel_config,
-                    scheduler_config=self.scheduler_config,
-                    device_config=self.device_config,
-                    cache_config=self.cache_config,
-                    load_config=self.load_config,
-                    local_rank=local_rank,
-                    rank=rank,
-                    distributed_init_method=distributed_init_method,
-                    lora_config=self.lora_config,
-                    vision_language_config=self.vision_language_config,
-                    is_driver_worker=rank == 0,
-                ))
+        init_worker_all_kwargs = [
+            self._get_worker_kwargs(
+                local_rank=node_workers[node_id].index(rank),
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
+        ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
@@ -199,8 +184,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return all_outputs[0]
 
     def _run_workers(
         self,

+ 10 - 0
aphrodite/spec_decode/spec_decode_worker.py

@@ -46,6 +46,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         correctness tests pass.
     """
 
+    @classmethod
+    def from_workers(cls, proposer_worker: MultiStepWorker,
+                     scorer_worker: WorkerBase) -> "SpecDecodeWorker":
+        return SpecDecodeWorker(
+            proposer_worker,
+            scorer_worker,
+            # TODO: disable strict mode for speedup.
+            rejection_sampler=RejectionSampler(strict_mode=True),
+        )
+
     @classmethod
     def create_worker(
         cls,
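The new classmethod gives callers such as GPUExecutor._init_spec_worker (see the gpu_executor.py hunk above) a one-line way to pair the proposer and scorer workers, with rejection sampling kept in strict mode for now per the TODO. A hedged usage sketch; draft_worker and target_worker stand in for the objects built by the executor helpers above:

# Usage sketch only; assumes draft_worker and target_worker already exist.
spec_decode_worker = SpecDecodeWorker.from_workers(
    proposer_worker=draft_worker,   # small draft model proposing tokens
    scorer_worker=target_worker)    # full target model scoring the proposals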