123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- from typing import List, Optional
- import torch
- from loguru import logger
- from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
- LoRAConfig, ModelConfig, ParallelConfig,
- PromptAdapterConfig, SchedulerConfig,
- SpeculativeConfig)
- from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
- from aphrodite.common.utils import make_async
- from aphrodite.executor.executor_base import ExecutorAsyncBase
- from aphrodite.executor.gpu_executor import GPUExecutor
- from aphrodite.task_handler.worker_base import WorkerWrapperBase
class XPUExecutor(GPUExecutor):
    """Executor for the XPU backend, layered on top of ``GPUExecutor``.

    The constructor is overridden so the model config can be adjusted by
    ``_verify_and_get_model_config`` before it is stored, and so that
    speculative decoding is rejected up front.
    """

    # This executor does not use Ray.
    uses_ray: bool = False

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        """Validate the configuration, store it, and start the executor.

        Args:
            model_config: Model settings; passed through
                ``_verify_and_get_model_config`` before being stored.
            cache_config: Stored as ``self.cache_config``.
            parallel_config: Stored as ``self.parallel_config``.
            scheduler_config: Stored as ``self.scheduler_config``.
            device_config: Must have ``device_type == "xpu"``.
            load_config: Stored as ``self.load_config``.
            lora_config: Optional; stored as ``self.lora_config``.
            prompt_adapter_config: Optional; stored as
                ``self.prompt_adapter_config``.
            speculative_config: Must be falsy; speculative decoding is not
                supported here.

        Raises:
            AssertionError: If the device type is not ``"xpu"`` or a
                speculative config is supplied.
        """
        assert device_config.device_type == "xpu"
        assert (not speculative_config
                ), "Speculative decoding not yet supported for XPU backend"

        # Store the adjusted model config rather than the raw argument.
        self.model_config = _verify_and_get_model_config(model_config)
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.prompt_adapter_config = prompt_adapter_config
        # Always None: a truthy speculative_config was rejected above.
        self.speculative_config = None

        # Inherited start-up hook (defined outside this block); it is
        # expected to instantiate the worker and load the model.
        self._init_executor()

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        """Build an ``XPUWorker`` through ``WorkerWrapperBase`` and return it.

        Args:
            local_rank: Forwarded to ``self._get_worker_kwargs``.
            rank: Forwarded to ``self._get_worker_kwargs``.
            distributed_init_method: Forwarded to
                ``self._get_worker_kwargs``.

        Returns:
            The ``worker`` attribute of the initialised wrapper.

        Raises:
            NotImplementedError: If ``self.speculative_config`` is set.
        """
        if self.speculative_config is not None:
            raise NotImplementedError(
                "XPU does not support speculative decoding")

        worker_wrapper = WorkerWrapperBase(
            worker_module_name="aphrodite.task_handler.xpu_worker",
            worker_class_name="XPUWorker",
        )
        init_kwargs = self._get_worker_kwargs(local_rank, rank,
                                              distributed_init_method)
        worker_wrapper.init_worker(**init_kwargs)
        return worker_wrapper.worker

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one request on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)
class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
    """``XPUExecutor`` combined with the ``ExecutorAsyncBase`` interface."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` via ``make_async``.

        Args:
            execute_model_req: The request, passed to the worker by keyword.

        Returns:
            Whatever the driver worker's ``execute_model`` returns.
        """
        # Wrap the worker's synchronous entry point so it can be awaited.
        run_model = make_async(self.driver_worker.execute_model)
        return await run_model(execute_model_req=execute_model_req)
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Adjust a model config in place for the XPU backend and return it.

    Two settings are overridden, each with a logged warning:

    * a ``torch.bfloat16`` dtype is replaced by ``torch.float16``;
    * a falsy ``enforce_eager`` is set to ``True``.

    Args:
        config: The model config to adjust. It is mutated, not copied.

    Returns:
        The same ``config`` object that was passed in.
    """
    if config.dtype == torch.bfloat16:
        logger.warning(
            "bfloat16 is not fully supported on XPU, casting to float16.")
        config.dtype = torch.float16

    # Eager mode already requested: nothing more to change.
    if config.enforce_eager:
        return config

    logger.warning(
        "CUDA graph is not supported on XPU, fallback to the eager "
        "mode.")
    config.enforce_eager = True
    return config
|