openvino.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # ruff: noqa: SIM117
  2. import os
  3. from pathlib import Path
  4. from typing import List, Optional, Tuple
  5. import openvino as ov
  6. import torch
  7. from huggingface_hub import HfApi
  8. from loguru import logger
  9. from openvino._offline_transformations import paged_attention_transformation
  10. from optimum.intel import OVModelForCausalLM
  11. from torch import nn
  12. from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
  13. from aphrodite.common.config import DeviceConfig, ModelConfig
  14. from aphrodite.common.sequence import SamplerOutput
  15. from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
  16. _prune_hidden_states)
  17. from aphrodite.modeling.layers.sampler import Sampler
  18. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  19. APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS = bool(
  20. os.getenv("APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False))
  21. def _flattenize_inputs(inputs):
  22. """
  23. Helper function for making nested inputs flattens
  24. """
  25. flatten_inputs = []
  26. for input_data in inputs:
  27. if input_data is None:
  28. continue
  29. if isinstance(input_data, (list, tuple)):
  30. flatten_inputs.extend(_flattenize_inputs(input_data))
  31. elif isinstance(input_data, dict):
  32. flatten_inputs.extend(_flattenize_inputs(list(
  33. input_data.values())))
  34. else:
  35. flatten_inputs.append(input_data)
  36. return flatten_inputs
  37. def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
  38. is_cpu: bool):
  39. # Apply hardware dependent modifications to KV tensors
  40. for parameter in model.get_parameters():
  41. input = parameter.get_output_tensor(0)
  42. input_names = input.get_names()
  43. if len(input_names) != 1:
  44. continue
  45. input_name = next(iter(input_names))
  46. shape = parameter.get_partial_shape()
  47. # use real block size if available, just a placeholder
  48. # to provide the expected rank
  49. x_size = 1
  50. num_blocks = ov.Dimension()
  51. block_size = ov.Dimension()
  52. head_size = ov.Dimension()
  53. # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
  54. # pass more parameters to this function to set more static dimensions
  55. if input_name.startswith("key_cache."):
  56. cpu_shape = [num_blocks, shape[1], block_size, head_size]
  57. gpu_shape = [
  58. num_blocks,
  59. shape[1],
  60. shape[2].get_length() //
  61. x_size if shape[2].is_static else ov.Dimension(),
  62. block_size,
  63. x_size,
  64. ]
  65. elif input_name.startswith("value_cache."):
  66. cpu_shape = [num_blocks, shape[1], block_size, head_size]
  67. gpu_shape = [num_blocks, shape[1], shape[2], block_size]
  68. else:
  69. continue
  70. parameter.set_partial_shape(
  71. ov.PartialShape(cpu_shape if is_cpu else gpu_shape))
  72. parameter.set_element_type(kv_cache_dtype)
  73. model.validate_nodes_and_infer_types()
  74. def _require_model_export(model_id, revision=None, subfolder=None):
  75. model_dir = Path(model_id)
  76. if subfolder is not None:
  77. model_dir = model_dir / subfolder
  78. if model_dir.is_dir():
  79. return (not (model_dir / "openvino_model.xml").exists()
  80. or not (model_dir / "openvino_model.bin").exists())
  81. hf_api = HfApi()
  82. try:
  83. model_info = hf_api.model_info(model_id, revision=revision or "main")
  84. normalized_subfolder = (None if subfolder is None else
  85. Path(subfolder).as_posix())
  86. model_files = [
  87. file.rfilename for file in model_info.siblings
  88. if normalized_subfolder is None
  89. or file.rfilename.startswith(normalized_subfolder)
  90. ]
  91. ov_model_path = ("openvino_model.xml" if normalized_subfolder is None
  92. else f"{normalized_subfolder}/openvino_model.xml")
  93. return (ov_model_path not in model_files
  94. or ov_model_path.replace(".xml", ".bin") not in model_files)
  95. except Exception:
  96. return True
  97. class OpenVINOCasualLM(nn.Module):
  98. def __init__(
  99. self,
  100. model_config: ModelConfig,
  101. device_config: DeviceConfig,
  102. kv_cache_dtype: ov.Type,
  103. ) -> None:
  104. super().__init__()
  105. self.logits_processor = LogitsProcessor(
  106. model_config.hf_config.vocab_size, logits_as_input=True)
  107. self.sampler = Sampler()
  108. export = _require_model_export(model_config.model)
  109. if export:
  110. logger.warning(
  111. f"Provided model id {model_config.model} does not " # noqa: G004
  112. "contain OpenVINO IR, the model will be converted to IR with "
  113. "default options. If you need to use specific options for "
  114. "model conversion, use optimum-cli export openvino with "
  115. "desired options.")
  116. else:
  117. logger.warning(
  118. "OpenVINO IR is available for provided model id " # noqa: G004
  119. f"{model_config.model}. This IR will be used for inference "
  120. "as-is, all possible options that may affect model conversion "
  121. "are ignored.")
  122. load_in_8bit = APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
  123. pt_model = OVModelForCausalLM.from_pretrained(
  124. model_config.model,
  125. export=export,
  126. compile=False,
  127. load_in_8bit=load_in_8bit,
  128. trust_remote_code=model_config.trust_remote_code,
  129. )
  130. paged_attention_transformation(pt_model.model)
  131. _modify_cache_parameters(pt_model.model, kv_cache_dtype,
  132. device_config.device.type == "cpu")
  133. core = ov.Core()
  134. ov_compiled = core.compile_model(pt_model.model, "CPU")
  135. self.ov_request = ov_compiled.create_infer_request()
  136. def forward(
  137. self,
  138. input_ids: torch.Tensor,
  139. positions: torch.Tensor,
  140. kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
  141. attn_metadata: OpenVINOAttentionMetadata,
  142. ) -> torch.Tensor:
  143. flatten_kv_cache = _flattenize_inputs(kv_caches)
  144. inputs = [
  145. input_ids,
  146. positions,
  147. *flatten_kv_cache,
  148. attn_metadata.past_lens,
  149. attn_metadata.subsequence_begins,
  150. attn_metadata.block_indices,
  151. attn_metadata.block_indices_begins,
  152. attn_metadata.max_context_len,
  153. ]
  154. self.ov_request.start_async(inputs, share_inputs=True)
  155. self.ov_request.wait()
  156. logits = torch.from_numpy(self.ov_request.get_tensor("logits").data)
  157. # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension
  158. return logits.view(-1, logits.shape[-1])
  159. def compute_logits(
  160. self,
  161. hidden_states: torch.Tensor,
  162. sampling_metadata: SamplingMetadata,
  163. ) -> Optional[torch.Tensor]:
  164. hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
  165. logits = self.logits_processor(None, hidden_states, sampling_metadata)
  166. return logits
  167. def sample(
  168. self,
  169. logits: torch.Tensor,
  170. sampling_metadata: SamplingMetadata,
  171. ) -> Optional[SamplerOutput]:
  172. next_tokens = self.sampler(logits, sampling_metadata)
  173. return next_tokens
  174. def get_model(
  175. model_config: ModelConfig,
  176. device_config: DeviceConfig,
  177. kv_cache_dtype: ov.Type,
  178. **kwargs,
  179. ) -> torch.nn.Module:
  180. lora_config = kwargs.get("lora_config", None)
  181. if lora_config:
  182. raise ValueError(
  183. "OpenVINO modeling does not support LoRA, "
  184. "but LoRA is enabled. Support for this model may "
  185. "be added in the future. If this is important to you, "
  186. "please open an issue on github.")
  187. return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)