import dataclasses
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
                    TypeVar)

import torch
import torch.nn as nn
from loguru import logger

from aphrodite.attention import get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig)
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (IntermediateTensors,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
from aphrodite.distributed import get_pp_group
from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader import get_model
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs, MultiModalRegistry)
from aphrodite.worker.model_runner import (AttentionMetadata,
                                           SamplingMetadata)
from aphrodite.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

_PAD_SLOT_ID = -1
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
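# For reference, the list above expands to [1, 2, 4, 8, 16, 24, ..., 256]:
# the three small warm-up sizes followed by every multiple of
# _BATCH_SIZE_ALIGNMENT (8) up to 8 * 32 = 256.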

TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")


@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
    """
    Used by the XPUModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForXPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForXPU:
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
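
# Usage sketch (illustrative only, not executed here): the pair of methods
# above is typically used to serialize the model input on the driver worker
# and rebuild it on the other workers, roughly as
#
#     tensor_dict = model_input.as_broadcastable_tensor_dict()
#     # ... broadcast tensor_dict to the non-driver workers ...
#     rebuilt = ModelInputForXPU.from_broadcasted_tensor_dict(
#         tensor_dict, attn_backend=runner.attn_backend)
#
# `model_input` and `runner` are placeholder names used only for this sketch.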

@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
    """
    Used by the XPUModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForXPUWithSamplingMetadata":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)

class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
    """Builds a ModelInputForXPU from a list of sequence group metadata."""

    def __init__(self,
                 runner: "XPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForXPU:
        is_prompt = self.seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(
                 self.seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(
                 self.seq_group_metadata_list)
            seq_lens = []
            multi_modal_kwargs = None

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy
                # slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
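                # Worked example (illustrative values only): with
                # self.block_size = 4 and block_table = [7, 3], token index
                # i = 6 lands in block_table[6 // 4] = 3 with offset
                # 6 % 4 = 2, so its slot is 3 * 4 + 2 = 14.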

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        max_seqlen = max(seq_lens)
        tmp = [0]
        tmp.extend(seq_lens)
        seqlen = torch.tensor(tmp)
        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
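        # Illustrative example (values are made up): for seq_lens = [4, 2, 3]
        # the cumulative sum above yields seqlen_q = [0, 4, 6, 9], i.e. the
        # start offset of each prompt inside the flattened token batch, with
        # the total token count as the final entry.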

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)
    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
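                    # Illustrative example (values are made up): with a
                    # sliding window of 1024 tokens and a block size of 16,
                    # only the last 1024 // 16 = 64 blocks are kept here.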
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )
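
# Usage sketch (illustrative only): the builder above is driven by
# XPUModelRunner._prepare_model_input_tensors further down, roughly as
#
#     builder = ModelInputForXPUBuilder(weakref.proxy(runner))
#     for seq_group_metadata in seq_group_metadata_list:
#         builder.add_seq_group(seq_group_metadata)
#     model_input = builder.build()
#
# where `runner` and `seq_group_metadata_list` are placeholder names.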

class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
        ModelInputForXPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.prompt_adapter_config = prompt_adapter_config
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            model_config.is_attention_free(),
        )

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took "
                    f"{self.model_memory_usage / float(2**30):.4f} GB")

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99,
                                         top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # the Aphrodite block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    f"Computed max_num_seqs ({expr}) to be less than 1. "
                    "Setting it to the minimum value of 1.")
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len
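            # Illustrative example (values are made up): with
            # max_num_batched_tokens = 2050 and max_num_seqs = 256, the first
            # 2050 % 256 = 2 dummy sequences get 9 tokens and the remaining
            # 254 get 8, so the lengths sum back to 2050.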

            seq_data, dummy_multi_modal_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.xpu.synchronize()
        return

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithSamplingMetadata:
        return (
            ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but
        not metadata for possible additional steps, e.g., sampling.
        """
        builder = self._builder_cls(weakref.proxy(self),
                                    finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)
        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group.
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device))

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        # Compute the logits.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]