openvino_executor.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from typing import List, Set, Tuple
  2. import openvino as ov
  3. import openvino.properties.hint as hints
  4. import torch
  5. from loguru import logger
  6. import aphrodite.common.envs as envs
  7. from aphrodite.common.config import CacheConfig, ModelConfig
  8. from aphrodite.common.sequence import ExecuteModelRequest
  9. from aphrodite.common.utils import (GiB_bytes, get_distributed_init_method,
  10. get_ip, get_open_port, make_async)
  11. from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
  12. from aphrodite.lora.request import LoRARequest
  13. from aphrodite.modeling.layers.sampler import SamplerOutput
  14. APHRODITE_OPENVINO_KVCACHE_SPACE = envs.APHRODITE_OPENVINO_KVCACHE_SPACE
  15. APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION = (
  16. envs.APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION)
  17. class OpenVINOExecutor(ExecutorBase):
  18. uses_ray: bool = False
  19. def _init_executor(self) -> None:
  20. assert self.device_config.device_type == "openvino"
  21. assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
  22. self.model_config = _verify_and_get_model_config(self.model_config)
  23. self.cache_config = _verify_and_get_cache_config(self.cache_config)
  24. # Instantiate the worker and load the model to CPU.
  25. self._init_worker()
  26. def _init_worker(self):
  27. from aphrodite.worker.openvino_worker import OpenVINOWorker
  28. assert (
  29. self.parallel_config.world_size == 1
  30. ), "OpenVINOExecutor only supports single CPU socket currently."
  31. distributed_init_method = get_distributed_init_method(
  32. get_ip(), get_open_port())
  33. self.driver_worker = OpenVINOWorker(
  34. model_config=self.model_config,
  35. parallel_config=self.parallel_config,
  36. scheduler_config=self.scheduler_config,
  37. device_config=self.device_config,
  38. cache_config=self.cache_config,
  39. load_config=self.load_config,
  40. local_rank=0,
  41. rank=0,
  42. distributed_init_method=distributed_init_method,
  43. lora_config=self.lora_config,
  44. kv_cache_dtype=self.cache_config.cache_dtype,
  45. is_driver_worker=True,
  46. )
  47. self.driver_worker.init_device()
  48. self.driver_worker.load_model()
  49. def determine_num_available_blocks(self) -> Tuple[int, int]:
  50. """Determine the number of available KV blocks by invoking the
  51. underlying worker.
  52. """
  53. return self.driver_worker.determine_num_available_blocks()
  54. def initialize_cache(self, num_gpu_blocks: int,
  55. num_cpu_blocks: int) -> None:
  56. """Initialize the KV cache by invoking the underlying worker."""
  57. # NOTE: We log here to avoid multiple logs when number of workers is
  58. # greater than one. We could log in the engine, but not all executors
  59. # have GPUs.
  60. # NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is
  61. # referred as `gpu block`. Because we want to reuse the existing block
  62. # management procedure.
  63. logger.info(f"# CPU blocks: {num_gpu_blocks}")
  64. logger.info(
  65. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x" # noqa: E501
  66. )
  67. self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  68. def execute_model(
  69. self,
  70. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  71. output = self.driver_worker.execute_model(execute_model_req)
  72. return output
  73. def add_lora(self, lora_request: LoRARequest) -> bool:
  74. return self.driver_worker.add_lora(lora_request)
  75. def remove_lora(self, lora_id: int) -> bool:
  76. return self.driver_worker.remove_lora(lora_id)
  77. def pin_lora(self, lora_id: int) -> bool:
  78. return self.driver_worker.pin_lora(lora_id)
  79. def list_loras(self) -> Set[int]:
  80. return self.driver_worker.list_loras()
  81. def add_prompt_adapter(self, prompt_adapter_request) -> bool:
  82. raise NotImplementedError(
  83. "Soft prompt is currently not supported by the OPENVINO backend.")
  84. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  85. raise NotImplementedError(
  86. "Soft prompt is currently not supported by the OPENVINO backend.")
  87. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  88. raise NotImplementedError(
  89. "Soft prompt is currently not supported by the OPENVINO backend.")
  90. def list_prompt_adapters(self) -> Set[int]:
  91. raise NotImplementedError(
  92. "Soft prompt is currently not supported by the OPENVINO backend.")
  93. def check_health(self) -> None:
  94. # OpenVINOExecutor will always be healthy as long as
  95. # it's running.
  96. return
  97. class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
  98. async def execute_model_async(
  99. self,
  100. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
  101. output = await make_async(self.driver_worker.execute_model
  102. )(execute_model_req=execute_model_req, )
  103. return output
  104. async def check_health_async(self) -> None:
  105. # OpenVINOExecutor will always be healthy as long as
  106. # it's running.
  107. return
  108. def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
  109. if config.dtype != torch.float32:
  110. logger.warning(
  111. f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501
  112. )
  113. config.dtype = torch.float32
  114. if not config.enforce_eager:
  115. logger.warning(
  116. "CUDA graph is not supported on OpenVINO backend, fallback to the "
  117. "eager mode.")
  118. config.enforce_eager = True
  119. return config
  120. def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
  121. if APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
  122. logger.info("KV cache type is overried to u8 via "
  123. "APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
  124. config.cache_dtype = ov.Type.u8
  125. else:
  126. core = ov.Core()
  127. inference_precision = core.get_property("CPU",
  128. hints.inference_precision)
  129. if inference_precision == ov.Type.bf16:
  130. config.cache_dtype = ov.Type.bf16
  131. else:
  132. config.cache_dtype = ov.Type.f16
  133. if config.block_size != 32:
  134. logger.info(
  135. f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
  136. )
  137. config.block_size = 32
  138. kv_cache_space = APHRODITE_OPENVINO_KVCACHE_SPACE
  139. if kv_cache_space >= 0:
  140. if kv_cache_space == 0:
  141. config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
  142. logger.warning(
  143. "Environment variable APHRODITE_OPENVINO_KVCACHE_SPACE (GB) "
  144. "for OpenVINO backend is not set, using 4 by default.")
  145. else:
  146. config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
  147. else:
  148. raise RuntimeError(
  149. "Invalid environment variable APHRODITE_OPENVINO_KVCACHE_SPACE"
  150. f" {kv_cache_space}, expect a positive integer value.")
  151. return config