# openvino_executor.py

import os
from typing import List, Set, Tuple

import openvino as ov
import openvino.properties.hint as hints
import torch
from loguru import logger

from aphrodite.common.config import CacheConfig, ModelConfig
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from aphrodite.lora.request import LoRARequest

APHRODITE_OPENVINO_KVCACHE_SPACE = int(
    os.getenv("APHRODITE_OPENVINO_KVCACHE_SPACE", 0))
APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION = os.getenv(
    "APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION", None)


class OpenVINOExecutor(ExecutorBase):

    uses_ray: bool = False

    def _init_executor(self) -> None:
        assert self.device_config.device_type == "openvino"
        assert self.lora_config is None, \
            "OpenVINO backend doesn't support LoRA"
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)

        # Instantiate the worker and load the model to CPU.
        self._init_worker()

    def _init_worker(self):
        from aphrodite.task_handler.openvino_worker import OpenVINOWorker

        assert (
            self.parallel_config.world_size == 1
        ), "OpenVINOExecutor only supports single CPU socket currently."

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = OpenVINOWorker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()
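
    # The "minimum concurrency" figure logged below is the number of
    # max_model_len-sized sequences the cache can hold at once; e.g. with
    # illustrative numbers, 1024 blocks * 32 tokens/block / 4096 max_model_len
    # = 8.00x.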

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        # NOTE: A `cpu block` for the OpenVINO backend is located in CPU
        # memory, but it is referred to as a `gpu block` so that the existing
        # block management procedure can be reused.
        logger.info(f"# CPU blocks: {num_gpu_blocks}")
        logger.info(
            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        output = self.driver_worker.execute_model(execute_model_req)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.driver_worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.driver_worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

    def add_prompt_adapter(self, prompt_adapter_request) -> bool:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def list_prompt_adapters(self) -> Set[int]:
        raise NotImplementedError(
            "Soft prompt is currently not supported by the OPENVINO backend.")

    def check_health(self) -> None:
        # OpenVINOExecutor will always be healthy as long as
        # it's running.
        return


class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # make_async offloads the blocking worker call so the event loop is
        # not stalled while the model runs.
        output = await make_async(self.driver_worker.execute_model
                                  )(execute_model_req=execute_model_req, )
        return output

    async def check_health_async(self) -> None:
        # OpenVINOExecutor will always be healthy as long as
        # it's running.
        return


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype != torch.float32:
        logger.warning(
            f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}."  # noqa: G004, E501
        )
        config.dtype = torch.float32
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on OpenVINO backend, falling back "
            "to eager mode.")
        config.enforce_eager = True
    return config
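

# Cache dtype selection below: the APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION
# override wins; otherwise the cache dtype follows the CPU plugin's inference
# precision (bf16 where the CPU prefers it, f16 otherwise). The space variable
# is read in GB, e.g. APHRODITE_OPENVINO_KVCACHE_SPACE=8 -> 8 * (1 << 30)
# = 8589934592 bytes.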


def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    if APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
        logger.info("KV cache type is overridden to u8 via "
                    "APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
        config.cache_dtype = ov.Type.u8
    else:
        core = ov.Core()
        inference_precision = core.get_property("CPU",
                                                hints.inference_precision)
        if inference_precision == ov.Type.bf16:
            config.cache_dtype = ov.Type.bf16
        else:
            config.cache_dtype = ov.Type.f16

    if config.block_size != 32:
        logger.info(
            f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
        )
        config.block_size = 32

    kv_cache_space = APHRODITE_OPENVINO_KVCACHE_SPACE
    if kv_cache_space >= 0:
        _GB = 1 << 30
        if kv_cache_space == 0:
            config.openvino_kvcache_space_bytes = 4 * _GB  # type: ignore
            logger.warning(
                "Environment variable APHRODITE_OPENVINO_KVCACHE_SPACE (GB) "
                "for OpenVINO backend is not set, using 4 by default.")
        else:
            config.openvino_kvcache_space_bytes = kv_cache_space * _GB  # type: ignore
    else:
        raise RuntimeError(
            "Invalid environment variable APHRODITE_OPENVINO_KVCACHE_SPACE"
            f" {kv_cache_space}, expect a non-negative integer value.")

    return config
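
# Example: inspecting the CPU inference precision that drives the default
# cache dtype (output is machine-dependent; the value shown is illustrative):
#
#   >>> import openvino as ov
#   >>> import openvino.properties.hint as hints
#   >>> ov.Core().get_property("CPU", hints.inference_precision)
#   <Type: 'bfloat16'>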