1
0

xpu_executor.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from typing import Callable, List, Optional, Tuple, Type, Union
  2. import torch
  3. from loguru import logger
  4. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  5. LoRAConfig, ModelConfig, ParallelConfig,
  6. PromptAdapterConfig, SchedulerConfig,
  7. SpeculativeConfig)
  8. from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
  9. SamplerOutput)
  10. from aphrodite.common.utils import make_async
  11. from aphrodite.executor.executor_base import ExecutorAsyncBase
  12. from aphrodite.executor.gpu_executor import GPUExecutor
  13. from aphrodite.task_handler.worker_base import WorkerBase
  class XPUExecutor(GPUExecutor):
      """Executor variant of ``GPUExecutor`` for the ``"xpu"`` device type.

      The constructor checks the supplied configs, stores them on the
      instance and then calls ``self._init_executor()`` (not defined in this
      class).
      """

      uses_ray: bool = False

      def __init__(
          self,
          model_config: ModelConfig,
          cache_config: CacheConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          load_config: LoadConfig,
          lora_config: Optional[LoRAConfig],
          prompt_adapter_config: Optional[PromptAdapterConfig],
          speculative_config: Optional[SpeculativeConfig],
      ) -> None:
          """Validate the configs, keep them on ``self`` and start the executor.

          Raises:
              AssertionError: If ``device_config.device_type`` is not ``"xpu"``
                  or if a truthy ``speculative_config`` is passed.
          """
          assert device_config.device_type == "xpu"
          assert not speculative_config, (
              "Speculative decoding not yet supported for XPU backend")

          # The model config goes through the XPU-specific check before it is
          # stored.
          self.model_config = _verify_and_get_model_config(model_config)

          # Remaining configs are stored unchanged.
          self.cache_config = cache_config
          self.load_config = load_config
          self.lora_config = lora_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          self.device_config = device_config
          self.prompt_adapter_config = prompt_adapter_config

          # Speculative decoding was rejected above, so nothing is stored.
          self.speculative_config = None

          # Instantiate the worker and load the model onto the device.
          self._init_executor()

      def _get_worker_module_and_class(
          self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
          """Return ``(module name, class name, class factory)`` for the worker.

          The factory element is always ``None`` here.

          Raises:
              NotImplementedError: If ``self.speculative_config`` is not
                  ``None``.
          """
          if self.speculative_config is not None:
              raise NotImplementedError(
                  "XPU does not support speculative decoding")
          return ("aphrodite.task_handler.xpu_worker", "XPUWorker", None)

      def execute_model(
          self, execute_model_req: ExecuteModelRequest
      ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
          """Pass the request to ``self.driver_worker`` and return its result."""
          # NOTE(review): ``driver_worker`` is not assigned in this class;
          # presumably set up by ``_init_executor()`` - confirm.
          return self.driver_worker.execute_model(execute_model_req)
  class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
      """Asynchronous counterpart of ``XPUExecutor``."""

      async def execute_model_async(
          self,
          execute_model_req: ExecuteModelRequest,
      ) -> List[SamplerOutput]:
          """Await the driver worker's ``execute_model`` for the given request.

          The worker method is wrapped with ``make_async`` and the request is
          passed by keyword.
          """
          # NOTE(review): ``make_async`` presumably turns the blocking call
          # into an awaitable - confirm against aphrodite.common.utils.
          run_async = make_async(self.driver_worker.execute_model)
          return await run_async(execute_model_req=execute_model_req)
  def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
      """Adjust ``config`` in place for the XPU backend and return it.

      Two settings are rewritten, each with a logged warning:

      * a ``torch.bfloat16`` dtype is replaced by ``torch.float16``;
      * a falsy ``enforce_eager`` is set to ``True``.

      Args:
          config: Model configuration to check; it is mutated, not copied.

      Returns:
          The same ``config`` object that was passed in.
      """
      uses_bf16 = config.dtype == torch.bfloat16
      if uses_bf16:
          logger.warning(
              "bfloat16 is not fully supported on XPU, casting to float16.")
          config.dtype = torch.float16

      # Eager mode already requested: nothing left to adjust.
      if config.enforce_eager:
          return config

      logger.warning(
          "CUDA graph is not supported on XPU, fallback to the eager "
          "mode.")
      config.enforce_eager = True
      return config