import dataclasses import weakref from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar) import torch import torch.nn as nn from loguru import logger from aphrodite.attention import get_attn_backend from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from aphrodite.common.sampling_params import SamplingParams from aphrodite.common.sequence import (IntermediateTensors, SequenceGroupMetadata) from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad from aphrodite.distributed import get_pp_group from aphrodite.inputs import INPUT_REGISTRY, InputRegistry from aphrodite.modeling.layers.sampler import SamplerOutput from aphrodite.modeling.model_loader import get_model from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) from aphrodite.worker.model_runner import (AttentionMetadata, SamplingMetadata) from aphrodite.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) if TYPE_CHECKING: from aphrodite.attention.backends.abstract import AttentionBackend _PAD_SLOT_ID = -1 _BATCH_SIZE_ALIGNMENT = 8 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") @dataclass(frozen=True) class ModelInputForXPU(ModelRunnerInputBase): """ Used by the NeuronModelRunner. """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @classmethod def from_broadcasted_tensor_dict( cls: Type[TModelInputForXPU], tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> TModelInputForXPU: if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) @dataclass(frozen=True) class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): """ Used by the ModelRunner. """ sampling_metadata: Optional["SamplingMetadata"] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, self.sampling_metadata) return tensor_dict @classmethod def from_broadcasted_tensor_dict( cls, tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForXPUWithSamplingMetadata": tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): def __init__(self, runner: "XPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.device = self.runner.device def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) def build(self) -> ModelInputForXPU: is_prompt = self.seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) = self._prepare_prompt( self.seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode( self.seq_group_metadata_list) seq_lens = [] multi_modal_kwargs = None return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, seq_lens=seq_lens, query_lens=seq_lens, ) def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() seq_len = len(prompt_tokens) seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: start_idx = max(0, seq_len - self.sliding_window) for i in range(computed_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] # type: ignore block_offset = i % self.block_size # type: ignore slot = block_number * self.block_size + block_offset slot_mapping.append(slot) num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore max_seqlen = max(seq_lens) tmp = [0] tmp.extend(seq_lens) seqlen = torch.tensor(tmp) seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, slot_mapping=slot_mapping, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, seq_lens_tensor=torch.tensor([]), max_decode_seq_len=0, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, block_tables=torch.tensor([], device=self.device, dtype=torch.int), ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 input_positions.append(position) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) block_tables = make_tensor_with_pad( block_tables, pad=0, dtype=torch.int, device=self.device, ) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0, seq_lens_tensor=seq_lens_tensor, max_decode_seq_len=max_decode_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, block_tables=block_tables, ) return ( input_tokens, input_positions, attn_metadata, ) class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = ( ModelInputForXPUWithSamplingMetadata) _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder def __init__( self, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker self.prompt_adapter_config = prompt_adapter_config self.return_hidden_states = return_hidden_states self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_context_len_to_capture = ( self.model_config.max_context_len_to_capture if self.model_config is not None else 0) self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, model_config.is_attention_free(), ) # Multi-modal data support self.input_registry = input_registry self.mm_registry = mm_registry self.multi_modal_input_mapper = mm_registry \ .create_input_mapper(model_config) self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization. self.model: nn.Module # Set after init_Model def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( model_config=self.model_config, device_config=self.device_config, load_config=self.load_config, lora_config=self.lora_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, cache_config=self.cache_config, ) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took " f"{self.model_memory_usage / float(2**30):.4f} GB") @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] # Additional GPU memory may be needed for multi-modal encoding, which # needs to be accounted for when calculating the GPU blocks for # Aphrodite blocker manager. # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: max_num_seqs_orig = max_num_seqs max_num_seqs = min(max_num_seqs, max_num_batched_tokens // max_mm_tokens) if max_num_seqs < 1: expr = (f"min({max_num_seqs_orig}, " f"{max_num_batched_tokens} // {max_mm_tokens})") logger.warning( f"Computed max_num_seqs ({expr}) to be less than 1. " "Setting it to the minimum value of 1.") max_num_seqs = 1 batch_size = 0 for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len seq_data, dummy_multi_modal_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=None, multi_modal_data=dummy_multi_modal_data, ) seqs.append(seq) # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = self.model.make_empty_intermediate_tensors( batch_size=batch_size, dtype=self.model_config.dtype, device=self.device) self.execute_model(model_input, kv_caches, intermediate_tensors) torch.xpu.synchronize() return def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any]) -> ModelInputForXPUWithSamplingMetadata: return ( ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, )) def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPUWithSamplingMetadata: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. """ builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: builder.add_seq_group(seq_group_metadata) return builder.build() # type: ignore def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForXPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) # Sampling metadata is only required for the final pp group generators = self.get_generators(finished_requests_ids) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, self.device, pin_memory=False, generators=generators) return dataclasses.replace(model_input, sampling_metadata=sampling_metadata, virtual_engine=virtual_engine) @torch.inference_mode() def execute_model( self, model_input: ModelInputForXPUWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( "XPUModelRunner does not support multi-step execution.") model_executable = self.model hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device)) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states # Compute the logits. logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: return [] # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits, sampling_metadata=model_input.sampling_metadata, ) return [output]