xpu_executor.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from typing import List, Optional
  2. import torch
  3. from loguru import logger
  4. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  5. LoRAConfig, ModelConfig, ParallelConfig,
  6. PromptAdapterConfig, SchedulerConfig,
  7. SpeculativeConfig)
  8. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  9. from aphrodite.common.utils import make_async
  10. from aphrodite.executor.executor_base import ExecutorAsyncBase
  11. from aphrodite.executor.gpu_executor import GPUExecutor
  12. from aphrodite.task_handler.worker_base import WorkerWrapperBase
  13. class XPUExecutor(GPUExecutor):
  14. uses_ray: bool = False
  15. def __init__(
  16. self,
  17. model_config: ModelConfig,
  18. cache_config: CacheConfig,
  19. parallel_config: ParallelConfig,
  20. scheduler_config: SchedulerConfig,
  21. device_config: DeviceConfig,
  22. load_config: LoadConfig,
  23. lora_config: Optional[LoRAConfig],
  24. prompt_adapter_config: Optional[PromptAdapterConfig],
  25. speculative_config: Optional[SpeculativeConfig],
  26. ) -> None:
  27. assert device_config.device_type == "xpu"
  28. assert (not speculative_config
  29. ), "Speculative decoding not yet supported for XPU backend"
  30. model_config = _verify_and_get_model_config(model_config)
  31. self.model_config = model_config
  32. self.cache_config = cache_config
  33. self.load_config = load_config
  34. self.lora_config = lora_config
  35. self.parallel_config = parallel_config
  36. self.scheduler_config = scheduler_config
  37. self.device_config = device_config
  38. self.prompt_adapter_config = prompt_adapter_config
  39. self.speculative_config = None
  40. # Instantiate the worker and load the model to GPU.
  41. self._init_executor()
  42. def _create_worker(self,
  43. local_rank: int = 0,
  44. rank: int = 0,
  45. distributed_init_method: Optional[str] = None):
  46. if self.speculative_config is None:
  47. worker_module_name = "aphrodite.task_handler.xpu_worker"
  48. worker_class_name = "XPUWorker"
  49. else:
  50. raise NotImplementedError(
  51. "XPU does not support speculative decoding")
  52. wrapper = WorkerWrapperBase(
  53. worker_module_name=worker_module_name,
  54. worker_class_name=worker_class_name,
  55. )
  56. wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
  57. distributed_init_method))
  58. return wrapper.worker
  59. def execute_model(
  60. self,
  61. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  62. output = self.driver_worker.execute_model(execute_model_req)
  63. return output
  64. class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
  65. async def execute_model_async(
  66. self,
  67. execute_model_req: ExecuteModelRequest,
  68. ) -> List[SamplerOutput]:
  69. output = await make_async(self.driver_worker.execute_model
  70. )(execute_model_req=execute_model_req)
  71. return output
  72. def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
  73. if config.dtype == torch.bfloat16:
  74. logger.warning(
  75. "bfloat16 is not fully supported on XPU, casting to float16.")
  76. config.dtype = torch.float16
  77. if not config.enforce_eager:
  78. logger.warning(
  79. "CUDA graph is not supported on XPU, fallback to the eager "
  80. "mode.")
  81. config.enforce_eager = True
  82. return config