openvino_worker.py

  1. """An OpenVINO worker class."""
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import openvino as ov
  4. import torch
  5. import torch.distributed
  6. from aphrodite.attention import get_attn_backend
  7. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  8. LoRAConfig, ModelConfig, MultiModalConfig,
  9. ParallelConfig, SchedulerConfig)
  10. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  11. from aphrodite.distributed import (broadcast_tensor_dict,
  12. ensure_model_parallel_initialized,
  13. init_distributed_environment)
  14. from aphrodite.modeling import set_random_seed
  15. from aphrodite.task_handler.openvino_model_runner import OpenVINOModelRunner
  16. from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase


class OpenVINOCacheEngine:
    """Manages the KV cache for the OpenVINO backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        assert device_config.device_type == "openvino"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        if device_config.device.type == "cpu" and \
                cache_config.cache_dtype == ov.Type.u8:
            # Scale, zero point and quantized data will be stored together.
            # The layout for per token per head:
            # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
            # so, we have to extend head_size by 8, which is sizeof(float)
            # for scale and sizeof(float) for zeropoint
            self.head_size += 8
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.block_size = cache_config.block_size

        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
        # for the OpenVINO backend, because we want to reuse the KV cache
        # management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free(),
        )

        # Initialize the cache.
        self.kv_cache: List[Tuple[ov.Tensor,
                                  ov.Tensor]] = self._allocate_kv_cache(
                                      self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates KV cache."""
        k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
        kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
        for _ in range(self.num_layers):
            key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                   k_block_shape)
            value_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                     v_block_shape)
            kv_cache.append((key_blocks, value_blocks))
        return kv_cache
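
    # All KV blocks for this backend live in host memory, so there is no swap
    # space to move blocks in and out of; both swap operations are rejected.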
    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError(
            "Swap is not supported in OpenVINOCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError(
            "Swap is not supported in OpenVINOCacheEngine.")

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: ov.Type,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
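        # Worked example with hypothetical sizes: block_size=16, 8 KV heads,
        # head_size=128, 32 layers and an f16 cache dtype give
        # 16 * 8 * 128 = 16,384 elements per key block, the same again for the
        # value block, so 32 * 2 * 16,384 elements * 2 bytes = 2 MiB per block.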
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        if cache_dtype == ov.Type.u8:
            # Scale, zero point and quantized data will be stored together.
            # The layout for per token per head:
            # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
            # so, we have to extend head_size by 8, which is sizeof(float)
            # for scale and sizeof(float) for zeropoint
            head_size += 8

        key_cache_block = block_size * num_kv_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = cache_dtype.size
        return dtype_size * total


class OpenVINOWorker(LoraNotSupportedWorkerBase):
    """A worker class that executes the model on the OpenVINO backend.

    Each worker is associated with a single OpenVINO device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    OpenVINO backend.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.model_runner = OpenVINOModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache().
        self.cache_engine: OpenVINOCacheEngine
        self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]

    def init_device(self) -> None:
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured
        KV cache space.

        Note that since Aphrodite assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of Aphrodite without generalizing
        it to different devices.
        """
        # For the OpenVINO backend, the block number is calculated based on
        # openvino_kvcache_space_bytes.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid."""
        if num_cpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `APHRODITE_OPENVINO_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`APHRODITE_OPENVINO_KVCACHE_SPACE` or decreasing "
                "`max_model_len` when initializing the engine.")

    def _init_cache_engine(self) -> None:
        self.cache_engine = OpenVINOCacheEngine(
            self.cache_config,
            self.model_config,
            self.parallel_config,
            self.device_config,
        )
        self.kv_cache = self.cache_engine.kv_cache
        self.model_runner.block_size = self.cache_engine.block_size

        assert self.kv_cache is not None

        # Populate the cache to warm up the memory.
        for key_cache, value_cache in self.kv_cache:
            key_cache.data[:] = 0
            value_cache.data[:] = 0

    def cache_copy(
        self,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        self.cache_engine.copy(blocks_to_copy)  # type: ignore

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
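
        # The driver worker builds the execution metadata and broadcasts it;
        # non-driver workers block on the broadcast below and take their
        # inputs from the received dictionary.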
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert execute_model_req is not None
            blocks_to_copy = execute_model_req.blocks_to_copy
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": execute_model_req.blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.kv_cache)

        # OpenVINO worker only supports single-step execution.
        return [output]

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
        )

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return OpenVINOCacheEngine.get_cache_block_size(
            self.cache_config.block_size,
            self.cache_config.cache_dtype,
            self.model_config,
            self.parallel_config,
        )