openvino.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # ruff: noqa: SIM117
  2. from pathlib import Path
  3. from typing import List, Optional, Tuple
  4. import openvino as ov
  5. import torch
  6. from huggingface_hub import HfApi
  7. from loguru import logger
  8. from openvino._offline_transformations import paged_attention_transformation
  9. from optimum.intel import OVModelForCausalLM
  10. from torch import nn
  11. import aphrodite.common.envs as envs
  12. from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
  13. from aphrodite.common.config import DeviceConfig, ModelConfig
  14. from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
  15. _prune_hidden_states)
  16. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  17. from aphrodite.modeling.sampling_metadata import SamplingMetadata
# Module-level copy of the environment setting, read once at import time.
# It is forwarded as ``load_in_8bit`` when the model is loaded in
# ``OpenVINOCasualLM.__init__``.
APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS = (
    envs.APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS)
  20. def _flattenize_inputs(inputs):
  21. """
  22. Helper function for making nested inputs flattens
  23. """
  24. flatten_inputs = []
  25. for input_data in inputs:
  26. if input_data is None:
  27. continue
  28. if isinstance(input_data, (list, tuple)):
  29. flatten_inputs.extend(_flattenize_inputs(input_data))
  30. elif isinstance(input_data, dict):
  31. flatten_inputs.extend(_flattenize_inputs(list(
  32. input_data.values())))
  33. else:
  34. flatten_inputs.append(input_data)
  35. return flatten_inputs
  36. def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
  37. is_cpu: bool):
  38. # Apply hardware dependent modifications to KV tensors
  39. for parameter in model.get_parameters():
  40. input = parameter.get_output_tensor(0)
  41. input_names = input.get_names()
  42. if len(input_names) != 1:
  43. continue
  44. input_name = next(iter(input_names))
  45. shape = parameter.get_partial_shape()
  46. # use real block size if available, just a placeholder
  47. # to provide the expected rank
  48. x_size = 1
  49. num_blocks = ov.Dimension()
  50. block_size = ov.Dimension()
  51. head_size = ov.Dimension()
  52. # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
  53. # pass more parameters to this function to set more static dimensions
  54. if input_name.startswith("key_cache."):
  55. cpu_shape = [num_blocks, shape[1], block_size, head_size]
  56. gpu_shape = [
  57. num_blocks,
  58. shape[1],
  59. shape[2].get_length() //
  60. x_size if shape[2].is_static else ov.Dimension(),
  61. block_size,
  62. x_size,
  63. ]
  64. elif input_name.startswith("value_cache."):
  65. cpu_shape = [num_blocks, shape[1], block_size, head_size]
  66. gpu_shape = [num_blocks, shape[1], shape[2], block_size]
  67. else:
  68. continue
  69. parameter.set_partial_shape(
  70. ov.PartialShape(cpu_shape if is_cpu else gpu_shape))
  71. parameter.set_element_type(kv_cache_dtype)
  72. model.validate_nodes_and_infer_types()
  73. def _require_model_export(model_id, revision=None, subfolder=None):
  74. model_dir = Path(model_id)
  75. if subfolder is not None:
  76. model_dir = model_dir / subfolder
  77. if model_dir.is_dir():
  78. return (not (model_dir / "openvino_model.xml").exists()
  79. or not (model_dir / "openvino_model.bin").exists())
  80. hf_api = HfApi()
  81. try:
  82. model_info = hf_api.model_info(model_id, revision=revision or "main")
  83. normalized_subfolder = (None if subfolder is None else
  84. Path(subfolder).as_posix())
  85. model_files = [
  86. file.rfilename for file in model_info.siblings
  87. if normalized_subfolder is None
  88. or file.rfilename.startswith(normalized_subfolder)
  89. ]
  90. ov_model_path = ("openvino_model.xml" if normalized_subfolder is None
  91. else f"{normalized_subfolder}/openvino_model.xml")
  92. return (ov_model_path not in model_files
  93. or ov_model_path.replace(".xml", ".bin") not in model_files)
  94. except Exception:
  95. return True
  96. class OpenVINOCasualLM(nn.Module):
  97. def __init__(
  98. self,
  99. model_config: ModelConfig,
  100. device_config: DeviceConfig,
  101. kv_cache_dtype: ov.Type,
  102. ) -> None:
  103. super().__init__()
  104. self.logits_processor = LogitsProcessor(
  105. model_config.hf_config.vocab_size, logits_as_input=True)
  106. self.sampler = Sampler()
  107. export = _require_model_export(model_config.model)
  108. if export:
  109. logger.warning(
  110. f"Provided model id {model_config.model} does not " # noqa: G004
  111. "contain OpenVINO IR, the model will be converted to IR with "
  112. "default options. If you need to use specific options for "
  113. "model conversion, use optimum-cli export openvino with "
  114. "desired options.")
  115. else:
  116. logger.warning(
  117. "OpenVINO IR is available for provided model id " # noqa: G004
  118. f"{model_config.model}. This IR will be used for inference "
  119. "as-is, all possible options that may affect model conversion "
  120. "are ignored.")
  121. load_in_8bit = APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
  122. pt_model = OVModelForCausalLM.from_pretrained(
  123. model_config.model,
  124. export=export,
  125. compile=False,
  126. load_in_8bit=load_in_8bit,
  127. trust_remote_code=model_config.trust_remote_code,
  128. )
  129. paged_attention_transformation(pt_model.model)
  130. _modify_cache_parameters(pt_model.model, kv_cache_dtype,
  131. device_config.device.type == "cpu")
  132. core = ov.Core()
  133. ov_compiled = core.compile_model(pt_model.model, "CPU")
  134. self.ov_request = ov_compiled.create_infer_request()
  135. def forward(
  136. self,
  137. input_ids: torch.Tensor,
  138. positions: torch.Tensor,
  139. kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
  140. attn_metadata: OpenVINOAttentionMetadata,
  141. ) -> torch.Tensor:
  142. flatten_kv_cache = _flattenize_inputs(kv_caches)
  143. inputs = [
  144. input_ids,
  145. positions,
  146. *flatten_kv_cache,
  147. attn_metadata.past_lens,
  148. attn_metadata.subsequence_begins,
  149. attn_metadata.block_indices,
  150. attn_metadata.block_indices_begins,
  151. attn_metadata.max_context_len,
  152. ]
  153. self.ov_request.start_async(inputs, share_inputs=True)
  154. self.ov_request.wait()
  155. logits = torch.from_numpy(self.ov_request.get_tensor("logits").data)
  156. # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension
  157. return logits.view(-1, logits.shape[-1])
  158. def compute_logits(
  159. self,
  160. hidden_states: torch.Tensor,
  161. sampling_metadata: SamplingMetadata,
  162. ) -> Optional[torch.Tensor]:
  163. hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
  164. logits = self.logits_processor(None, hidden_states, sampling_metadata)
  165. return logits
  166. def sample(
  167. self,
  168. logits: torch.Tensor,
  169. sampling_metadata: SamplingMetadata,
  170. ) -> Optional[SamplerOutput]:
  171. next_tokens = self.sampler(logits, sampling_metadata)
  172. return next_tokens
  173. def get_model(
  174. model_config: ModelConfig,
  175. device_config: DeviceConfig,
  176. kv_cache_dtype: ov.Type,
  177. **kwargs,
  178. ) -> torch.nn.Module:
  179. lora_config = kwargs.get("lora_config", None)
  180. if lora_config:
  181. raise ValueError(
  182. "OpenVINO modeling does not support LoRA, "
  183. "but LoRA is enabled. Support for this model may "
  184. "be added in the future. If this is important to you, "
  185. "please open an issue on github.")
  186. return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)