openvino_worker.py

"""An OpenVINO worker class."""
from typing import Any, Dict, List, Optional, Tuple

import openvino as ov
import torch
import torch.distributed

from aphrodite.attention import get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.distributed import (broadcast_tensor_dict,
                                   ensure_model_parallel_initialized,
                                   init_distributed_environment)
from aphrodite.modeling import set_random_seed
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.worker.openvino_model_runner import OpenVINOModelRunner
from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase


class OpenVINOCacheEngine:
    """Manages the KV cache for the OpenVINO backend.

    This class is responsible for initializing and managing CPU KV caches.
    It also provides methods for performing KV cache operations, such as
    copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        assert device_config.device_type == "openvino"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        if device_config.device.type == "cpu" and \
                cache_config.cache_dtype == ov.Type.u8:
            # Scale, zero point and quantized data are stored together.
            # The layout per token per head is:
            # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
            # so head_size has to be extended by 8 bytes: sizeof(float)
            # for the scale and sizeof(float) for the zero point.
            self.head_size += 8
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks actually holds num_cpu_blocks
        # for the OpenVINO backend, because we want to reuse the KV cache
        # management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        # Get the attention backend.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free(),
        )

        # Initialize the cache.
        self.kv_cache: List[Tuple[ov.Tensor,
                                  ov.Tensor]] = self._allocate_kv_cache(
                                      self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates KV cache."""
        k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
        kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
        for _ in range(self.num_layers):
            key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                   k_block_shape)
            value_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                     v_block_shape)
            kv_cache.append((key_blocks, value_blocks))
        return kv_cache
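
    # The resulting cache layout (as built above) is one (key, value) pair of
    # ov.Tensor block pools per decoder layer:
    #   kv_cache = [(key_blocks_layer_0,   value_blocks_layer_0),
    #               ...,
    #               (key_blocks_layer_N-1, value_blocks_layer_N-1)]
    # where each tensor uses the per-layer shape returned by
    # attn_backend.get_kv_cache_shape(...)[1:].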

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError(
            "Swap is not supported in OpenVINOCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError(
            "Swap is not supported in OpenVINOCacheEngine.")

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: ov.Type,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        if cache_dtype == ov.Type.u8:
            # Scale, zero point and quantized data are stored together.
            # The layout per token per head is:
            # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
            # so head_size has to be extended by 8 bytes: sizeof(float)
            # for the scale and sizeof(float) for the zero point.
            head_size += 8

        key_cache_block = block_size * num_kv_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = cache_dtype.size
        return dtype_size * total
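
    # Illustrative sizing example (hypothetical values, not defaults): for
    # num_layers=32, num_kv_heads=8, head_size=128, block_size=16 and
    # cache_dtype=ov.Type.f32 (4 bytes per element), the formula above gives:
    #   key_cache_block   = 16 * 8 * 128           = 16,384 elements
    #   value_cache_block = key_cache_block        = 16,384 elements
    #   total             = 32 * (16,384 + 16,384) = 1,048,576 elements
    #   block size        = 1,048,576 * 4 bytes    = 4 MiB per KV cache block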


class OpenVINOWorker(LoraNotSupportedWorkerBase):
    """A worker class that executes the model on the OpenVINO backend.

    Each worker is associated with a single OpenVINO device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    OpenVINO backend.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # Note: lazy import to avoid importing torch before initializing.
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.model_runner = OpenVINOModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine; it is set up later by initialize_cache().
        self.cache_engine: OpenVINOCacheEngine
        self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]

    def init_device(self) -> None:
        self.init_distributed_environment()
        # Set the random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured
        KV cache space.

        Note that since Aphrodite assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of Aphrodite without generalizing
        it to different devices.
        """
        # For the OpenVINO backend, the number of blocks is computed from
        # openvino_kvcache_space_bytes.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # we treat the CPU cache as the 'GPU cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks
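
    # Illustrative example (hypothetical numbers): with
    # openvino_kvcache_space_bytes = 8 GiB (8 * 1024**3 bytes) and a cache
    # block size of 4 MiB (see OpenVINOCacheEngine.get_cache_block_size),
    # this yields 8 GiB // 4 MiB = 2048 blocks, reported to the scheduler as
    # num_gpu_blocks=2048 and num_cpu_blocks=0.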

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # we treat the CPU cache as the 'GPU cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise an error if num_cpu_blocks is invalid."""
        if num_cpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `APHRODITE_OPENVINO_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`APHRODITE_OPENVINO_KVCACHE_SPACE` or decreasing "
                "`max_model_len` when initializing the engine.")
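
    # Illustrative example (hypothetical numbers): with block_size=16 and
    # num_cpu_blocks=2048, the cache can hold at most 16 * 2048 = 32,768
    # tokens, so a model configured with max_model_len > 32,768 would be
    # rejected here.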

    def _init_cache_engine(self) -> None:
        self.cache_engine = OpenVINOCacheEngine(
            self.cache_config,
            self.model_config,
            self.parallel_config,
            self.device_config,
        )
        self.kv_cache = self.cache_engine.kv_cache
        self.model_runner.block_size = self.cache_engine.block_size

        assert self.kv_cache is not None

        # Populate the cache to warm up the memory.
        for key_cache, value_cache in self.kv_cache:
            key_cache.data[:] = 0
            value_cache.data[:] = 0

    def cache_copy(
        self,
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        self.cache_engine.copy(blocks_to_copy)  # type: ignore

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert execute_model_req is not None
            blocks_to_copy = execute_model_req.blocks_to_copy
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": execute_model_req.blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.kv_cache)

        # The OpenVINO worker only supports single-step execution.
        return [output]

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
        )

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return OpenVINOCacheEngine.get_cache_block_size(
            self.cache_config.block_size,
            self.cache_config.cache_dtype,
            self.model_config,
            self.parallel_config,
        )
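

# Minimal usage sketch (illustrative only): in practice the worker is created
# and driven by Aphrodite's executor, and the *_config objects below are
# assumed to have been built elsewhere (e.g. from the engine arguments).
#
#     worker = OpenVINOWorker(
#         model_config, parallel_config, scheduler_config, device_config,
#         cache_config, load_config, local_rank=0, rank=0,
#         distributed_init_method="tcp://127.0.0.1:2641",  # hypothetical
#         kv_cache_dtype=ov.Type.u8, is_driver_worker=True,
#     )
#     worker.init_device()
#     worker.load_model()
#     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
#     worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
#     outputs = worker.execute_model(execute_model_req)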