openvino.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # ruff: noqa: SIM117
  2. from pathlib import Path
  3. from typing import List, Optional, Tuple
  4. import openvino as ov
  5. import torch
  6. from huggingface_hub import HfApi
  7. from loguru import logger
  8. from openvino._offline_transformations import paged_attention_transformation
  9. from optimum.intel import OVModelForCausalLM
  10. from torch import nn
  11. import aphrodite.common.envs as envs
  12. from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
  13. from aphrodite.common.config import DeviceConfig, ModelConfig
  14. from aphrodite.common.sequence import SamplerOutput
  15. from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
  16. _prune_hidden_states)
  17. from aphrodite.modeling.layers.sampler import Sampler
  18. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  19. APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS = (
  20. envs.APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS)
  21. def _flattenize_inputs(inputs):
  22. """
  23. Helper function for making nested inputs flattens
  24. """
  25. flatten_inputs = []
  26. for input_data in inputs:
  27. if input_data is None:
  28. continue
  29. if isinstance(input_data, (list, tuple)):
  30. flatten_inputs.extend(_flattenize_inputs(input_data))
  31. elif isinstance(input_data, dict):
  32. flatten_inputs.extend(_flattenize_inputs(list(
  33. input_data.values())))
  34. else:
  35. flatten_inputs.append(input_data)
  36. return flatten_inputs
  37. def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
  38. is_cpu: bool):
  39. # Apply hardware dependent modifications to KV tensors
  40. for parameter in model.get_parameters():
  41. input = parameter.get_output_tensor(0)
  42. input_names = input.get_names()
  43. if len(input_names) != 1:
  44. continue
  45. input_name = next(iter(input_names))
  46. shape = parameter.get_partial_shape()
  47. # use real block size if available, just a placeholder
  48. # to provide the expected rank
  49. x_size = 1
  50. num_blocks = ov.Dimension()
  51. block_size = ov.Dimension()
  52. head_size = ov.Dimension()
  53. # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
  54. # pass more parameters to this function to set more static dimensions
  55. if input_name.startswith("key_cache."):
  56. cpu_shape = [num_blocks, shape[1], block_size, head_size]
  57. gpu_shape = [
  58. num_blocks,
  59. shape[1],
  60. shape[2].get_length() //
  61. x_size if shape[2].is_static else ov.Dimension(),
  62. block_size,
  63. x_size,
  64. ]
  65. elif input_name.startswith("value_cache."):
  66. cpu_shape = [num_blocks, shape[1], block_size, head_size]
  67. gpu_shape = [num_blocks, shape[1], shape[2], block_size]
  68. else:
  69. continue
  70. parameter.set_partial_shape(
  71. ov.PartialShape(cpu_shape if is_cpu else gpu_shape))
  72. parameter.set_element_type(kv_cache_dtype)
  73. model.validate_nodes_and_infer_types()
  74. def _require_model_export(model_id, revision=None, subfolder=None):
  75. model_dir = Path(model_id)
  76. if subfolder is not None:
  77. model_dir = model_dir / subfolder
  78. if model_dir.is_dir():
  79. return (not (model_dir / "openvino_model.xml").exists()
  80. or not (model_dir / "openvino_model.bin").exists())
  81. hf_api = HfApi()
  82. try:
  83. model_info = hf_api.model_info(model_id, revision=revision or "main")
  84. normalized_subfolder = (None if subfolder is None else
  85. Path(subfolder).as_posix())
  86. model_files = [
  87. file.rfilename for file in model_info.siblings
  88. if normalized_subfolder is None
  89. or file.rfilename.startswith(normalized_subfolder)
  90. ]
  91. ov_model_path = ("openvino_model.xml" if normalized_subfolder is None
  92. else f"{normalized_subfolder}/openvino_model.xml")
  93. return (ov_model_path not in model_files
  94. or ov_model_path.replace(".xml", ".bin") not in model_files)
  95. except Exception:
  96. return True
  97. class OpenVINOCasualLM(nn.Module):
  98. def __init__(
  99. self,
  100. model_config: ModelConfig,
  101. device_config: DeviceConfig,
  102. kv_cache_dtype: ov.Type,
  103. ) -> None:
  104. super().__init__()
  105. self.logits_processor = LogitsProcessor(
  106. model_config.hf_config.vocab_size, logits_as_input=True)
  107. self.sampler = Sampler()
  108. export = _require_model_export(model_config.model)
  109. if export:
  110. logger.warning(
  111. f"Provided model id {model_config.model} does not " # noqa: G004
  112. "contain OpenVINO IR, the model will be converted to IR with "
  113. "default options. If you need to use specific options for "
  114. "model conversion, use optimum-cli export openvino with "
  115. "desired options.")
  116. else:
  117. logger.warning(
  118. "OpenVINO IR is available for provided model id " # noqa: G004
  119. f"{model_config.model}. This IR will be used for inference "
  120. "as-is, all possible options that may affect model conversion "
  121. "are ignored.")
  122. load_in_8bit = APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
  123. pt_model = OVModelForCausalLM.from_pretrained(
  124. model_config.model,
  125. export=export,
  126. compile=False,
  127. load_in_8bit=load_in_8bit,
  128. trust_remote_code=model_config.trust_remote_code,
  129. )
  130. paged_attention_transformation(pt_model.model)
  131. _modify_cache_parameters(pt_model.model, kv_cache_dtype,
  132. device_config.device.type == "cpu")
  133. core = ov.Core()
  134. ov_compiled = core.compile_model(pt_model.model, "CPU")
  135. self.ov_request = ov_compiled.create_infer_request()
  136. def forward(
  137. self,
  138. input_ids: torch.Tensor,
  139. positions: torch.Tensor,
  140. kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
  141. attn_metadata: OpenVINOAttentionMetadata,
  142. ) -> torch.Tensor:
  143. flatten_kv_cache = _flattenize_inputs(kv_caches)
  144. inputs = [
  145. input_ids,
  146. positions,
  147. *flatten_kv_cache,
  148. attn_metadata.past_lens,
  149. attn_metadata.subsequence_begins,
  150. attn_metadata.block_indices,
  151. attn_metadata.block_indices_begins,
  152. attn_metadata.max_context_len,
  153. ]
  154. self.ov_request.start_async(inputs, share_inputs=True)
  155. self.ov_request.wait()
  156. logits = torch.from_numpy(self.ov_request.get_tensor("logits").data)
  157. # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension
  158. return logits.view(-1, logits.shape[-1])
  159. def compute_logits(
  160. self,
  161. hidden_states: torch.Tensor,
  162. sampling_metadata: SamplingMetadata,
  163. ) -> Optional[torch.Tensor]:
  164. hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
  165. logits = self.logits_processor(None, hidden_states, sampling_metadata)
  166. return logits
  167. def sample(
  168. self,
  169. logits: torch.Tensor,
  170. sampling_metadata: SamplingMetadata,
  171. ) -> Optional[SamplerOutput]:
  172. next_tokens = self.sampler(logits, sampling_metadata)
  173. return next_tokens
  174. def get_model(
  175. model_config: ModelConfig,
  176. device_config: DeviceConfig,
  177. kv_cache_dtype: ov.Type,
  178. **kwargs,
  179. ) -> torch.nn.Module:
  180. lora_config = kwargs.get("lora_config", None)
  181. if lora_config:
  182. raise ValueError(
  183. "OpenVINO modeling does not support LoRA, "
  184. "but LoRA is enabled. Support for this model may "
  185. "be added in the future. If this is important to you, "
  186. "please open an issue on github.")
  187. return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)