瀏覽代碼

xpu: refactor XPU worker & executor (#861)

AlpinDale 3 月之前
父節點
當前提交
9094a8a2a3
共有 2 個文件被更改,包括 18 次插入和 23 次刪除
  1. 11 21
      aphrodite/executor/xpu_executor.py
  2. 7 2
      aphrodite/task_handler/xpu_worker.py

+ 11 - 21
aphrodite/executor/xpu_executor.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 from loguru import logger
@@ -7,11 +7,11 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
+                                       SamplerOutput)
 from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
 
 
 class XPUExecutor(GPUExecutor):
@@ -49,28 +49,18 @@ class XPUExecutor(GPUExecutor):
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "aphrodite.task_handler.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "aphrodite.task_handler.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
 

+ 7 - 2
aphrodite/task_handler/xpu_worker.py

@@ -63,8 +63,9 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                   "Driver worker should be rank 0 of tensor parallel group."
 
         self.multimodal_config = multimodal_config
 
@@ -175,7 +176,11 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
             # dependency (libdrm and drm headers) on your system.
             ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                 "sockets")
+            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                             str(parallel_config.world_size))
             os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+            os.environ["LOCAL_RANK"] = str(self.local_rank)
             init_distributed_environment(
                 world_size=parallel_config.world_size,
                 rank=rank,