Centralize GPU worker construction

AlpinDale · 8 months ago
commit 7bcf4c3fc9
3 changed files with 57 additions and 68 deletions
  1. aphrodite/executor/gpu_executor.py (+39 -44)
  2. aphrodite/executor/ray_gpu_executor.py (+8 -24)
  3. aphrodite/spec_decode/spec_decode_worker.py (+10 -0)

aphrodite/executor/gpu_executor.py (+39 -44)

@@ -1,10 +1,11 @@
-from typing import Dict, List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 from loguru import logger
 
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
+from aphrodite.task_handler.worker_base import WorkerWrapperBase
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
 
@@ -22,30 +23,47 @@ class GPUExecutor(ExecutorBase):
         else:
             self._init_spec_worker()
 
-    def _init_non_spec_worker(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
+    def _get_worker_kwargs(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
             model_config=self.model_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
             load_config=self.load_config,
-            local_rank=0,
-            rank=0,
+            local_rank=local_rank,
+            rank=rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            is_driver_worker=rank == 0,
+        )
+
+    def _create_worker(self,
+                       local_rank: int = 0,
+                       rank: int = 0,
+                       distributed_init_method: Optional[str] = None):
+        wrapper = WorkerWrapperBase(
+            worker_module_name="aphrodite.task_handler.worker",
+            worker_class_name="Worker",
         )
+        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
+                                                      distributed_init_method))
+        return wrapper.worker
+
+    def _init_non_spec_worker(self):
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
@@ -56,41 +74,18 @@ class GPUExecutor(ExecutorBase):
 
         from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
         from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
-        from aphrodite.task_handler.worker import Worker
 
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-
-        target_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
+        target_worker = self._create_worker()
 
-        draft_worker = MultiStepWorker(
+        draft_worker_kwargs = self._get_worker_kwargs()
+        # Override draft-model specific worker args.
+        draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
             # TODO allow draft-model specific load config.
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            #load_config=self.load_config,
         )
+        draft_worker = MultiStepWorker(**draft_worker_kwargs)
 
         spec_decode_worker = SpecDecodeWorker.from_workers(
             proposer_worker=draft_worker, scorer_worker=target_worker)
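
Note on the pattern above: all single-GPU worker construction now flows through _get_worker_kwargs() and _create_worker(), and WorkerWrapperBase defers importing the Worker class until init_worker() runs, so torch.cuda/xformers are not imported before CUDA_VISIBLE_DEVICES is set. Below is a minimal sketch of that deferred-import idea, assuming WorkerWrapperBase behaves roughly like this; the class name LazyWorkerWrapper and its body are illustrative stand-ins, not the library's actual implementation.

    import importlib
    from typing import Any, Optional

    class LazyWorkerWrapper:
        """Illustrative stand-in for WorkerWrapperBase (assumption)."""

        def __init__(self, worker_module_name: str,
                     worker_class_name: str) -> None:
            self.worker_module_name = worker_module_name
            self.worker_class_name = worker_class_name
            self.worker: Optional[Any] = None

        def init_worker(self, **kwargs: Any) -> None:
            # Import only now, after the environment (e.g.
            # CUDA_VISIBLE_DEVICES) has been prepared, then build the
            # real worker from the kwargs produced by _get_worker_kwargs().
            module = importlib.import_module(self.worker_module_name)
            worker_cls = getattr(module, self.worker_class_name)
            self.worker = worker_cls(**kwargs)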

aphrodite/executor/ray_gpu_executor.py (+8 -24)

@@ -151,29 +151,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        def collect_arg_helper_func(**kwargs):
-            # avoid writing `{"name": value}` manually
-            return kwargs
-
         # Initialize the actual workers inside worker wrapper.
-        init_worker_all_kwargs = []
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
-            local_rank = node_workers[node_id].index(rank)
-            init_worker_all_kwargs.append(
-                collect_arg_helper_func(
-                    model_config=self.model_config,
-                    parallel_config=self.parallel_config,
-                    scheduler_config=self.scheduler_config,
-                    device_config=self.device_config,
-                    cache_config=self.cache_config,
-                    load_config=self.load_config,
-                    local_rank=local_rank,
-                    rank=rank,
-                    distributed_init_method=distributed_init_method,
-                    lora_config=self.lora_config,
-                    vision_language_config=self.vision_language_config,
-                    is_driver_worker=rank == 0,
-                ))
+        init_worker_all_kwargs = [
+            self._get_worker_kwargs(
+                local_rank=node_workers[node_id].index(rank),
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
+        ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
@@ -199,8 +184,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return all_outputs[0]
 
     def _run_workers(
         self,
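
The Ray executor now builds its per-rank init kwargs with the same _get_worker_kwargs() helper instead of a local collect_arg_helper_func. Below is a small self-contained sketch of how local_rank is derived from the node layout; the node names, GPU ids, and the get_worker_kwargs stand-in are hypothetical, and only the rank/local_rank wiring mirrors the diff.

    from typing import Dict, List, Tuple

    # Hypothetical layout: two nodes with two Ray workers each,
    # enumerated in global rank order.
    worker_node_and_gpu_ids: List[Tuple[str, List[int]]] = [
        ("node-a", [0]), ("node-a", [1]), ("node-b", [0]), ("node-b", [1]),
    ]
    node_workers: Dict[str, List[int]] = {"node-a": [0, 1], "node-b": [2, 3]}

    def get_worker_kwargs(local_rank: int, rank: int,
                          distributed_init_method: str) -> dict:
        # Stand-in for GPUExecutor._get_worker_kwargs(); only the rank
        # fields matter for this illustration.
        return dict(local_rank=local_rank, rank=rank,
                    distributed_init_method=distributed_init_method,
                    is_driver_worker=rank == 0)

    init_worker_all_kwargs = [
        get_worker_kwargs(
            local_rank=node_workers[node_id].index(rank),
            rank=rank,
            distributed_init_method="tcp://10.0.0.1:29500",
        ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
    ]
    # Rank 2 is the first worker on node-b, so its local_rank is 0, and
    # only rank 0 is marked as the driver worker.
    assert init_worker_all_kwargs[2]["local_rank"] == 0
    assert init_worker_all_kwargs[0]["is_driver_worker"]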

aphrodite/spec_decode/spec_decode_worker.py (+10 -0)

@@ -46,6 +46,16 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         correctness tests pass.
     """
 
+    @classmethod
+    def from_workers(cls, proposer_worker: MultiStepWorker,
+                     scorer_worker: WorkerBase) -> "SpecDecodeWorker":
+        return SpecDecodeWorker(
+            proposer_worker,
+            scorer_worker,
+            # TODO: disable strict mode for speedup.
+            rejection_sampler=RejectionSampler(strict_mode=True),
+        )
+
     @classmethod
     def create_worker(
         cls,
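
The new from_workers() classmethod keeps the RejectionSampler default (strict_mode=True for now, with a TODO to disable it for speed once correctness is validated) next to the class rather than in each executor. Below is a minimal sketch of that factory-classmethod pattern with placeholder types; _SpecDecodeWorker, _Worker, and _Sampler are illustrative stand-ins, not the real Aphrodite classes, whose constructors take more arguments.

    from dataclasses import dataclass

    @dataclass
    class _Sampler:
        strict_mode: bool = False

    @dataclass
    class _Worker:
        name: str

    class _SpecDecodeWorker:
        def __init__(self, proposer: _Worker, scorer: _Worker,
                     rejection_sampler: _Sampler) -> None:
            self.proposer = proposer
            self.scorer = scorer
            self.rejection_sampler = rejection_sampler

        @classmethod
        def from_workers(cls, proposer: _Worker,
                         scorer: _Worker) -> "_SpecDecodeWorker":
            # Callers supply only the two workers; the sampler default
            # lives in one place instead of in every executor.
            return cls(proposer, scorer,
                       rejection_sampler=_Sampler(strict_mode=True))

    worker = _SpecDecodeWorker.from_workers(_Worker("draft"), _Worker("target"))
    assert worker.rejection_sampler.strict_mode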