from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from loguru import logger

from aphrodite.attention import get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig)
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
from aphrodite.distributed import broadcast_tensor_dict
from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
from aphrodite.modeling.model_loader import get_model
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs, MultiModalRegistry)
from aphrodite.task_handler.model_runner import (AttentionMetadata,
                                                 SamplingMetadata)
from aphrodite.task_handler.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

_PAD_SLOT_ID = -1
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
    """
    Model input used by the XPUModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForXPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForXPU":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
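
# Illustrative round-trip of the broadcast helpers above (a sketch only, not
# executable as-is: it assumes an already populated `model_input` and an
# initialized attention backend `attn_backend`):
#
#     tensor_dict = model_input.as_broadcastable_tensor_dict()
#     broadcast_tensor_dict(tensor_dict, src=0)               # driver side
#     restored = ModelInputForXPU.from_broadcasted_tensor_dict(
#         tensor_dict, attn_backend=attn_backend)             # worker side
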
class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.cache_config = cache_config
        self.prompt_adapter_config = prompt_adapter_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker
        self.sliding_window = model_config.get_sliding_window()
        self.device_config = device_config
        self.device = self.device_config.device
        self.kv_cache_dtype = kv_cache_dtype
        self.block_size = cache_config.block_size
        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            model_config.is_attention_free(),
        )

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model().

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # the Aphrodite block manager.
        # To exercise the worst-case scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    f"Computed max_num_seqs ({expr}) to be less than 1. "
                    "Setting it to the minimum value of 1.")
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            seq_data, dummy_multi_modal_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        model_input = self.prepare_model_input(seqs)
        self.execute_model(model_input, kv_caches)
        torch.xpu.synchronize()
        return
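
    # Worked example of the dummy-sequence token split in profile_run above
    # (illustrative numbers only, not from any real config): with
    # max_num_batched_tokens = 10 and max_num_seqs = 4, the dummy sequences
    # get lengths 3, 3, 2, 2 -- the remainder (10 % 4 = 2) is spread over the
    # first two groups so the lengths sum to exactly 10.
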
    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU:
        return (ModelInputForXPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        ))

    def prepare_model_input(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            virtual_engine: int = 0,
            finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPU:
        multi_modal_kwargs = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are either all
            # prompts or all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_kwargs
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # subquery_lens is not needed if chunked prefill is not
                # supported. Since this worker doesn't support chunked
                # prefill, just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False,
                generators=self.get_generators(finished_requests_ids))
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "multi_modal_kwargs": multi_modal_kwargs,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )

        return ModelInputForXPU(input_tokens=input_tokens,
                                input_positions=input_positions,
                                attn_metadata=attn_metadata,
                                sampling_metadata=sampling_metadata,
                                multi_modal_kwargs=multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=None,
            max_seqlen=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )
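
    # Worked example of the decode slot computation in _prepare_decode above
    # (illustrative numbers only): with block_size = 4, a sequence of length
    # 10 writes its new token at position 9, so block_number =
    # block_table[9 // 4] = block_table[2], block_offset = 9 % 4 = 1, and
    # slot = block_number * 4 + 1.
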
    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]
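
    # Worked example of the prefill cumulative-length tensor seqlen_q built
    # in _prepare_prompt below (illustrative numbers only): for
    # seq_lens = [4, 2, 3], the code prepends 0 and takes a cumulative sum,
    # giving seqlen_q = [0, 4, 6, 9]; consecutive entries give the start
    # offset of each prompt in the flattened token batch, with the total
    # token count at the end.
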
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy
                # slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        max_seqlen = max(seq_lens)
        tmp = [0]
        tmp.extend(seq_lens)
        seqlen = torch.tensor(tmp)
        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=None,
            max_decode_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device,
                                      dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)
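
# Typical driver-side flow, as a sketch only (construction of the config
# objects and of `seq_group_metadata_list` / `kv_caches` is handled by the
# surrounding worker and is not shown here):
#
#     runner = XPUModelRunner(model_config, parallel_config, scheduler_config,
#                             device_config, cache_config, load_config,
#                             lora_config=None, multimodal_config=None,
#                             is_driver_worker=True)
#     runner.load_model()
#     model_input = runner.prepare_model_input(seq_group_metadata_list)
#     outputs = runner.execute_model(model_input, kv_caches)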