# openvino_model_runner.py

from typing import List, Mapping, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from aphrodite.attention import get_attn_backend
from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader.openvino import get_model
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                                  MultiModalInputs)


class ModelInput(NamedTuple):
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    multi_modal_kwargs: Mapping[str, BatchedTensors]

    @classmethod
    def empty(cls, device):
        return ModelInput(input_tokens=torch.empty(0, device=device),
                          input_positions=torch.empty(0, device=device),
                          attn_metadata=None,
                          seq_lens=[],
                          query_lens=[],
                          multi_modal_kwargs={})


class OpenVINOModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model()

    def load_model(self) -> None:
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            kv_cache_dtype=self.kv_cache_dtype,
        )

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The resulting tensors and data structures also batch the input in
        prefill -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []

        # Initialize the beginning of the prefix sums.
        subsequence_begins.append(0)
        block_indices_begins.append(0)

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding,
                    # so we need special logic here.
                    # TODO: Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization: get_token_ids would copy the entire token
                    # list; for decode we only need the last token.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix caching is not supported with sliding_window.
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data:
                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
                    multi_modal_inputs_list.append(mm_kwargs)

                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO: Combine chunked prefill and prefix caching by
                # only allowing a chunk size that is a multiple of block_size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # Chunked prefill or decode.
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if self.sliding_window is not None:
                            # Chunked prefill doesn't support sliding window.
                            assert not self.scheduler_config.chunked_prefill_enabled  # noqa: E501
                            sliding_window_blocks = (self.sliding_window //
                                                     self.block_size)
                            block_table = block_table[-sliding_window_blocks:]
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prompt phase without prefix caching or chunked prefill.
                    pass

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO: This is a hack to make sliding window work with
                # paged attention. We can remove it if we make the paged
                # attention kernel handle sliding window attention properly.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                input_positions.extend(list(range(computed_len, seq_len)))

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )
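
        # A hypothetical example (not from the source) of the resulting
        # metadata for a single decode sequence with 7 tokens already in the
        # KV cache and block_size == 4:
        #
        #   past_lens            = [7]
        #   subsequence_begins   = [0, 1]     # one new query token
        #   block_indices        = <the 2 block ids from the block table>
        #   block_indices_begins = [0, 2]
        #   max_context_len      = 8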

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, Mapping[str, BatchedTensors]]:
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            **(multi_modal_kwargs or {}),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
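

# A minimal usage sketch (illustrative only): the config objects, the
# sequence group metadata, and the kv_caches list are assumed to be supplied
# by the owning worker, which is not part of this file.
#
#   runner = OpenVINOModelRunner(model_config, parallel_config,
#                                scheduler_config, device_config,
#                                cache_config, load_config,
#                                lora_config=None, multimodal_config=None)
#   runner.load_model()
#   sampler_output = runner.execute_model(seq_group_metadata_list, kv_caches)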