# openvino_model_runner.py

from typing import List, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from aphrodite.attention import get_attn_backend
from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import SequenceGroupMetadata
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.openvino import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs)


class ModelInput(NamedTuple):
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        return ModelInput(input_tokens=torch.empty(0, device=device),
                          input_positions=torch.empty(0, device=device),
                          attn_metadata=None,
                          seq_lens=[],
                          query_lens=[],
                          multi_modal_kwargs={})
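
# Example sketch (illustrative device only): when the scheduler hands the
# runner an empty batch, _prepare_model_input() below falls back to
# ModelInput.empty(device), i.e. zero-length token/position tensors, no
# attention metadata, and empty length lists:
#
#     mi = ModelInput.empty(torch.device("cpu"))
#     assert mi.attn_metadata is None
#     assert mi.seq_lens == [] and mi.input_tokens.numel() == 0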


class OpenVINOModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set by load_model().

    def load_model(self) -> None:
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            kv_cache_dtype=self.kv_cache_dtype,
        )

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The resulting tensors and data structures are also batched in
        prefill -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []

        # Initialize the beginning of the prefix sums.
        subsequence_begins.append(0)
        block_indices_begins.append(0)
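
        # Worked example (hypothetical batch, for illustration): one prefill
        # sequence with 5 new tokens and a 2-entry block table, followed by
        # one decode sequence with 1 new token and a 3-entry block table.
        # After the loop below, the prefix sums look like:
        #
        #     subsequence_begins   == [0, 5, 6]   # query-token offsets
        #     block_indices_begins == [0, 2, 5]   # block-table offsets
        #
        # i.e. entries i and i + 1 bracket the slice owned by sequence i.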

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding,
                    # so we need special logic here.
                    # TODO: Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization: get_token_ids requires a copy of the
                    # entire token sequence.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix caching is not supported with sliding_window.
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data:
                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
                    multi_modal_inputs_list.append(mm_kwargs)

                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO: Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # Chunked prefill or decode.
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if self.sliding_window is not None:
                            # Chunked prefill doesn't support sliding window.
                            assert not self.scheduler_config.chunked_prefill_enabled  # noqa: E501
                            sliding_window_blocks = (self.sliding_window //
                                                     self.block_size)
                            block_table = block_table[-sliding_window_blocks:]
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prompt phase without prefix caching or chunked prefill.
                    pass

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO: This is a hack to make sliding window work with
                # paged attn. We can remove it if we make the paged attn
                # kernel properly handle sliding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                input_positions.extend(list(range(computed_len, seq_len)))

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)
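
        # Example of the per-sequence bookkeeping above (hypothetical
        # numbers): a prompt of 7 tokens with nothing computed yet yields
        # computed_len == 0, seq_len == 7, query_len == 7; a decode step on a
        # sequence of length 10 yields computed_len == 9, seq_len == 10,
        # query_len == 1, which is what the final assertion checks.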

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )
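
        # Shape sketch for the metadata built above: with B sequences in the
        # batch, past_lens has B entries, subsequence_begins and
        # block_indices_begins each have B + 1 entries (prefix sums), and
        # block_indices is the concatenation of the per-sequence block tables;
        # max_context_len is a scalar. All are int32 tensors on self.device.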

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
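

# Rough call-flow sketch (the config objects are whatever the engine already
# holds; only the methods defined above are assumed): the worker constructs
# the runner, loads the model once, then drives each scheduler step with the
# prepared sequence group metadata plus the OpenVINO KV-cache tensor pairs:
#
#     runner = OpenVINOModelRunner(model_config, parallel_config,
#                                  scheduler_config, device_config,
#                                  cache_config, load_config,
#                                  lora_config=None, multimodal_config=None,
#                                  kv_cache_dtype="auto")
#     runner.load_model()
#     sampler_output = runner.execute_model(seq_group_metadata_list, kv_caches)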