from typing import List, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from aphrodite.attention import get_attn_backend
from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.modeling.model_loader.openvino import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs)


class ModelInput(NamedTuple):
    """Flattened model inputs for one scheduler step.

    Produced by ``OpenVINOModelRunner._prepare_model_input``; tokens of all
    sequences are concatenated in prefill -> decode order.
    """
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device) -> "ModelInput":
        """Return a ModelInput describing an empty batch on ``device``."""
        # Use the same integer dtype as the non-empty path in
        # _prepare_model_input so consumers always see torch.long tensors.
        return cls(
            input_tokens=torch.empty(0, dtype=torch.long, device=device),
            input_positions=torch.empty(0, dtype=torch.long, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            multi_modal_kwargs={},
        )


class OpenVINOModelRunner:
    """Builds batched inputs from scheduler metadata and runs the OpenVINO
    model (forward pass, logits, sampling)."""

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store the configs and set up the attention backend and the
        multi-modal input mapper.

        Extra positional/keyword arguments are accepted and ignored. The
        model itself is not loaded here; call ``load_model`` first.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set in load_model

    def load_model(self) -> None:
        """Load the model via the OpenVINO model loader."""
        self.model = get_model(
            model_config=self.model_config,
            device_config=self.device_config,
            kv_cache_dtype=self.kv_cache_dtype,
        )

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Raises:
            RuntimeError: if chunked prefill is enabled and a group reports
                computed (prefix-cached) blocks.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        # Prefix-sum arrays consumed by the OpenVINO attention metadata;
        # both start at 0.
        subsequence_begins: List[int] = [0]
        block_indices: List[int] = []
        block_indices_begins: List[int] = [0]

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        # Read the flag once, treating a missing scheduler config the same
        # way everywhere (the original guarded None in one place only).
        chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt
            computed_block_nums = seq_group_metadata.computed_block_nums
            block_tables = seq_group_metadata.block_tables
            has_computed_blocks = (computed_block_nums is not None
                                   and len(computed_block_nums) > 0)

            # computed_block_nums is per group, so this check is too.
            if chunked_prefill_enabled and has_computed_blocks:
                raise RuntimeError(
                    "chunked prefill cannot be used with prefix caching "
                    "now.")

            # Prefix cache was hit.
            # Prefix is not supported with sliding_window
            prefix_cache_hit = (has_computed_blocks
                                and self.sliding_window is None
                                and is_prompt)

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO: Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data:
                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
                    multi_modal_inputs_list.append(mm_kwargs)

                # block_tables is None only during memory profiling; guard
                # the lookup so that case yields an empty table instead of
                # raising TypeError.
                block_table = (block_tables[seq_id]
                               if block_tables is not None else [])

                # TODO: Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif ((chunked_prefill_enabled or not is_prompt)
                      and block_tables is not None
                      and self.sliding_window is not None):
                    # Chunked prefill or decode with a sliding window: keep
                    # only the trailing blocks covered by the window.
                    # chunked prefill doesn't support sliding window.
                    assert not chunked_prefill_enabled
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                # Otherwise (prompt phase w/o prefix caching / chunked
                # prefill, or no sliding window) the table is used as-is.

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # Positions must be the true (absolute) token positions, so
                # capture them before the sliding-window clamp below, which
                # only adjusts the attention context lengths.
                positions = range(computed_len, seq_len)

                # TODO: This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                input_positions.extend(positions)

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)
        subsequence_begins_tensor = torch.tensor(subsequence_begins,
                                                 dtype=torch.int32,
                                                 device=self.device)
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)
        block_indices_begins_tensor = torch.tensor(block_indices_begins,
                                                   dtype=torch.int32,
                                                   device=self.device)

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(max_context_len,
                                              dtype=torch.int32,
                                              device=self.device)

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return ModelInput(
            input_tokens_tensor,
            input_positions_tensor,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        """Build model inputs plus sampling metadata for one step.

        Returns:
            (input_tokens, input_positions, attn_metadata, sampling_metadata,
            multi_modal_kwargs)
        """
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass, compute logits, and sample next tokens.

        Returns:
            Whatever ``self.model.sample`` returns for this batch.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            input_tokens,
            "positions":
            input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            attn_metadata,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output