|
@@ -1,4 +1,4 @@
|
|
|
-from typing import List, Optional
|
|
|
+from typing import List, Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
from loguru import logger
|
|
@@ -7,11 +7,11 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
|
|
|
LoRAConfig, ModelConfig, ParallelConfig,
|
|
|
PromptAdapterConfig, SchedulerConfig,
|
|
|
SpeculativeConfig)
|
|
|
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
|
|
|
+from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
|
|
|
+ SamplerOutput)
|
|
|
from aphrodite.common.utils import make_async
|
|
|
from aphrodite.executor.executor_base import ExecutorAsyncBase
|
|
|
from aphrodite.executor.gpu_executor import GPUExecutor
|
|
|
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
|
|
|
|
|
|
|
|
|
class XPUExecutor(GPUExecutor):
|
|
@@ -49,28 +49,18 @@ class XPUExecutor(GPUExecutor):
|
|
|
# Instantiate the worker and load the model to GPU.
|
|
|
self._init_executor()
|
|
|
|
|
|
- def _create_worker(self,
|
|
|
- local_rank: int = 0,
|
|
|
- rank: int = 0,
|
|
|
- distributed_init_method: Optional[str] = None):
|
|
|
- if self.speculative_config is None:
|
|
|
- worker_module_name = "aphrodite.task_handler.xpu_worker"
|
|
|
- worker_class_name = "XPUWorker"
|
|
|
- else:
|
|
|
+ def _get_worker_module_and_class(self) -> Tuple[str, str]:
|
|
|
+ if self.speculative_config is not None:
|
|
|
raise NotImplementedError(
|
|
|
"XPU does not support speculative decoding")
|
|
|
-
|
|
|
- wrapper = WorkerWrapperBase(
|
|
|
- worker_module_name=worker_module_name,
|
|
|
- worker_class_name=worker_class_name,
|
|
|
- )
|
|
|
- wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
|
|
|
- distributed_init_method))
|
|
|
- return wrapper.worker
|
|
|
+ else:
|
|
|
+ worker_module_name = "aphrodite.task_handler.xpu_worker"
|
|
|
+ worker_class_name = "XPUWorker"
|
|
|
+ return (worker_module_name, worker_class_name)
|
|
|
|
|
|
def execute_model(
|
|
|
- self,
|
|
|
- execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
|
|
+ self, execute_model_req: ExecuteModelRequest
|
|
|
+ ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
|
|
output = self.driver_worker.execute_model(execute_model_req)
|
|
|
return output
|
|
|
|