
chore: refactor executor classes for easier inheritance (#840)

AlpinDale committed 3 months ago (parent commit 7c7ec12f36)
2 changed files with 26 additions and 22 deletions:
  1. aphrodite/executor/gpu_executor.py (+17, -12)
  2. aphrodite/executor/ray_gpu_executor.py (+9, -10)

aphrodite/executor/gpu_executor.py (+17, -12)

@@ -61,6 +61,18 @@ class GPUExecutor(ExecutorBase):
             or (rank % self.parallel_config.tensor_parallel_size == 0),
         )
 
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.scheduler_config.is_multi_step:
+            worker_module_name = "aphrodite.task_handler.multi_step_worker"
+            worker_class_name = "MultiStepWorker"
+        elif self.speculative_config:
+            worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+        else:
+            worker_module_name = "aphrodite.task_handler.worker"
+            worker_class_name = "Worker"
+        return (worker_module_name, worker_class_name)
+
     def _get_create_worker_kwargs(
             self,
             local_rank: int = 0,
@@ -68,18 +80,11 @@ class GPUExecutor(ExecutorBase):
             distributed_init_method: Optional[str] = None) -> Dict:
         worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                 distributed_init_method)
-        if self.scheduler_config.is_multi_step:
-            worker_kwargs.update(
-                worker_module_name="aphrodite.task_handler.multi_step_worker",
-                worker_class_name="MultiStepWorker")
-        elif self.speculative_config:
-            worker_kwargs.update(
-                worker_module_name="aphrodite.spec_decode.spec_decode_worker",
-                worker_class_name="create_spec_worker")
-        else:
-            worker_kwargs.update(
-                worker_module_name="aphrodite.task_handler.worker",
-                worker_class_name="Worker")
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
+        worker_kwargs.update(worker_module_name=worker_module_name,
+                             worker_class_name=worker_class_name)
+
         return worker_kwargs
 
     def _create_worker(self,

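With the selection logic pulled into `_get_worker_module_and_class`, a subclass now has a single override point instead of reimplementing all of `_get_create_worker_kwargs`. A minimal sketch of the intended inheritance pattern; the subclass, module path, and worker class below are hypothetical, not part of this commit:

    from typing import Tuple

    from aphrodite.executor.gpu_executor import GPUExecutor

    class CustomGPUExecutor(GPUExecutor):  # hypothetical subclass
        def _get_worker_module_and_class(self) -> Tuple[str, str]:
            # Route all work to a custom worker; the module/class names
            # here are made up for illustration.
            return ("my_project.custom_worker", "CustomWorker")

Both `GPUExecutor._get_create_worker_kwargs` and `RayGPUExecutor._get_worker_wrapper_args` now consult this one method, so the override applies to both execution paths.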
aphrodite/executor/ray_gpu_executor.py (+9, -10)

@@ -101,15 +101,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
         return ray_remote_kwargs
 
     def _get_worker_wrapper_args(self) -> Dict[str, Any]:
-        if self.speculative_config is not None:
-            worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
-            worker_class_name = "create_spec_worker"
-        elif self.scheduler_config.is_multi_step:
-            worker_module_name = "aphrodite.task_handler.multi_step_worker"
-            worker_class_name = "MultiStepWorker"
-        else:
-            worker_module_name = "aphrodite.task_handler.worker"
-            worker_class_name = "Worker"
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
 
         return dict(
             worker_module_name=worker_module_name,
@@ -117,6 +110,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
             trust_remote_code=self.model_config.trust_remote_code,
         )
 
+    # Child classes can override this to return the actual env vars.
+    def _get_env_vars_to_be_updated(self):
+        return self._env_vars_for_all_workers
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1
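The new `_get_env_vars_to_be_updated` hook follows the same pattern: the base implementation returns the env vars captured below in `_init_workers_ray`, and a child class can override it to adjust them. A hedged sketch, assuming a hypothetical subclass and flag name; note that each element of `_env_vars_for_all_workers` is a one-element tuple holding a dict, matching how the diff builds it:

    from aphrodite.executor.ray_gpu_executor import RayGPUExecutor

    class CustomRayGPUExecutor(RayGPUExecutor):  # hypothetical subclass
        def _get_env_vars_to_be_updated(self):
            # Start from the env vars the base class captured in
            # _init_workers_ray, then add a made-up flag per worker.
            return [({**args[0], "MY_CUSTOM_FLAG": "1"}, )
                    for args in self._env_vars_for_all_workers]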
@@ -240,8 +237,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
             "APHRODITE_TRACE_FUNCTION":
             str(APHRODITE_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
         self._run_workers("update_environment_variables",
-                          all_args=all_args_to_update_environment_variables)
+                          all_args=self._get_env_vars_to_be_updated())
 
         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
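Assigning the computed args to `self._env_vars_for_all_workers` before the `_run_workers` call is what makes the hook usable: the defaults are captured first, then the (possibly overridden) hook is consulted. A self-contained sketch of that capture-then-dispatch order, with simplified class and variable names standing in for the real executor:

    class Base:
        def _init_workers(self):
            defaults = [({"APHRODITE_TRACE_FUNCTION": "0"}, )]
            # Capture the computed defaults *before* dispatching, exactly
            # as the diff does, so an overriding hook can read them.
            self._env_vars_for_all_workers = defaults
            return self._get_env_vars_to_be_updated()

        def _get_env_vars_to_be_updated(self):
            return self._env_vars_for_all_workers

    class Child(Base):
        def _get_env_vars_to_be_updated(self):
            # Hypothetical override that augments the captured defaults.
            return [({**args[0], "EXTRA": "1"}, )
                    for args in self._env_vars_for_all_workers]

    print(Child()._init_workers())
    # [({'APHRODITE_TRACE_FUNCTION': '0', 'EXTRA': '1'},)]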