# openvino_model_runner.py

from typing import List, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from aphrodite.attention import get_attn_backend
from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader.openvino import get_model
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs)


class ModelInput(NamedTuple):
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        return ModelInput(input_tokens=torch.empty(0, device=device),
                          input_positions=torch.empty(0, device=device),
                          attn_metadata=None,
                          seq_lens=[],
                          query_lens=[],
                          multi_modal_kwargs={})
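
# Note (illustrative): ModelInput.empty() is the value _prepare_model_input
# returns when the scheduler passes an empty seq_group_metadata_list, so
# every field is still present (empty tensors / empty lists) in that case.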


class OpenVINOModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.

    def load_model(self) -> None:
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            kv_cache_dtype=self.kv_cache_dtype,
        )

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structures also batch input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []

        # Initialize the beginning of the prefix sums.
        subsequence_begins.append(0)
        block_indices_begins.append(0)
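
        # Illustration: subsequence_begins is a prefix sum of query lengths
        # and block_indices_begins a prefix sum of per-sequence block-table
        # lengths. E.g. query lengths [3, 1] give subsequence_begins
        # [0, 3, 4]; block tables of lengths [2, 1] give block_indices_begins
        # [0, 2, 3], with block_indices holding the concatenated block ids.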

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding,
                    # so we need special logic here.
                    # TODO: Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization: get_token_ids copies the entire token
                    # list, so only fetch the last token for decode.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix caching is not supported with sliding_window.
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data:
                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
                    multi_modal_inputs_list.append(mm_kwargs)

                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO: Combine chunked prefill and prefix caching by
                # only allowing a chunk size that is a multiple of block_size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # Chunked prefill or decode.
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if self.sliding_window is not None:
                            # Chunked prefill doesn't support sliding window.
                            assert not self.scheduler_config.chunked_prefill_enabled  # noqa: E501
                            sliding_window_blocks = (self.sliding_window //
                                                     self.block_size)
                            block_table = block_table[-sliding_window_blocks:]
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prompt phase without prefix caching or chunked prefill.
                    pass

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO: This is a hack to make sliding window work with
                # paged attention. We can remove it once the paged attention
                # kernel handles sliding-window attention properly.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                input_positions.extend(list(range(computed_len, seq_len)))

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore
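
        # The OpenVINO paged-attention metadata below describes the flattened
        # batch built above: past_lens holds the already-computed token count
        # per subsequence, subsequence_begins / block_indices_begins are the
        # prefix sums over query lengths and block-table lengths, and
        # block_indices is the concatenation of all block tables.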
        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output