import dataclasses
import itertools
from typing import Any, Dict, List, Optional, Tuple, Type, cast

import torch
import torch.distributed
from loguru import logger

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionMetadata)
from aphrodite.attention.backends.utils import PAD_SLOT_ID
from aphrodite.attention.selector import (_Backend,
                                          get_env_variable_attn_backend,
                                          get_global_forced_attn_backend,
                                          global_force_attn_backend)
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig)
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (  # noqa: F401
    IntermediateTensors, PoolerOutput, SequenceGroupMetadata)
from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
                                    make_tensor_with_pad)
from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from aphrodite.worker.model_runner import (
    GPUModelRunnerBase, ModelInputForGPUBuilder,
    ModelInputForGPUWithSamplingMetadata, _get_graph_batch_size)
from aphrodite.worker.model_runner_base import (
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict)
from aphrodite.worker.utils import assert_enc_dec_mr_supported_scenario


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
    """Model input used by the EncoderDecoderModelRunner.

    Adds the encoder token/position tensors on top of the fields
    inherited from `ModelInputForGPUWithSamplingMetadata`.
    """
    encoder_input_tokens: Optional[torch.Tensor] = None
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Flatten this model input into a dict suitable for broadcasting.

        The attention and sampling metadata are merged into the dict by
        the `_add_*_broadcastable_dict` helpers.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "encoder_input_tokens": self.encoder_input_tokens,
            "encoder_input_positions": self.encoder_input_positions,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInput":
        """Rebuild a model input from a broadcasted tensor dict.

        Delegates to the parent implementation; the `cast` only narrows
        the static type and performs no runtime conversion.
        """
        return cast(
            EncoderDecoderModelInput,
            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))


class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
    """GPU model runner for encoder/decoder models."""

    _model_input_cls: Type[EncoderDecoderModelInput] = (
        EncoderDecoderModelInput)
    _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        **kwargs,
    ):
        """EncoderDecoderModelRunner constructor.

        `lora_config` and `prompt_adapter_config` are unused (since these
        features are not yet supported for encoder/decoder models) but
        these arguments are present here for compatibility with the
        base-class constructor.
        """
        # The backend must be forced before the base constructor runs.
        self._maybe_force_supported_attention_backend()

        # NOTE(review): `input_registry` and `mm_registry` are accepted
        # above but not forwarded to the base constructor, so a caller's
        # custom registries are ignored. Confirm against the base-class
        # signature whether they should be passed through.
        super().__init__(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            lora_config=None,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            **kwargs,
        )

        # Crash for unsupported encoder/scenarios
        assert_enc_dec_mr_supported_scenario(self)

    def _maybe_force_supported_attention_backend(self):
        """Force Aphrodite to use the XFormers attention backend, which
        is currently the only supported option.

        Raises:
            NotImplementedError: if a backend other than XFormers has
                already been forced, either globally or through the
                backend environment variable.
        """

        def raise_backend_err():
            # The user has specified an attention backend override
            # which is invalid for encoder/decoder models
            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)

        maybe_env_var_forced_backend = get_env_variable_attn_backend()
        maybe_global_forced_backend = get_global_forced_attn_backend()
        is_forced_by_global = maybe_global_forced_backend is not None
        is_forced_by_env_var = maybe_env_var_forced_backend is not None

        if not (is_forced_by_global or is_forced_by_env_var):
            # The user has not already specified an attention backend
            # override
            logger.info("EncoderDecoderModelRunner requires "
                        "XFormers backend; overriding backend "
                        "auto-selection and forcing XFormers.")
            global_force_attn_backend(_Backend.XFORMERS)
        elif is_forced_by_global:
            # Backend override enforced by global variable takes
            # precedence over Aphrodite backend environment variable.
            if maybe_global_forced_backend != _Backend.XFORMERS:
                raise_backend_err()
        elif is_forced_by_env_var:
            # Backend override enforced by Aphrodite backend
            # environment variable
            if maybe_env_var_forced_backend != _Backend.XFORMERS:
                raise_backend_err()

    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Return `_list` as an int32 tensor on `self.device`."""
        return torch.tensor(_list, dtype=torch.int32, device=self.device)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Return `_list` as an int64 tensor on `self.device`."""
        return torch.tensor(_list, dtype=torch.long, device=self.device)

    def _empty_int32_tensor(self) -> torch.Tensor:
        """Return an empty int32 tensor on `self.device`."""
        return self._list_to_int32_tensor([])

    def _empty_long_tensor(self) -> torch.Tensor:
        """Return an empty int64 tensor on `self.device`."""
        return self._list_to_long_tensor([])

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample.

        Returns:
            An empty list on non-driver workers; otherwise a
            single-element list holding the sampler output.

        Raises:
            ValueError: if `num_steps` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in "
                             "EncoderDecoderModelRunner")

        # A decode-only batch flagged for CUDA graphs is replayed through
        # the captured graph runner keyed by its batch size.
        if (model_input.attn_metadata is not None
                and model_input.attn_metadata.prefill_metadata is None
                and model_input.attn_metadata.decode_metadata.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[
                model_input.virtual_engine][graph_batch_size]
        else:
            model_executable = self.model

        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}

        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            encoder_input_ids=model_input.encoder_input_tokens,
            encoder_positions=model_input.encoder_input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **seqlen_agnostic_kwargs)

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only the driver worker samples.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        return [output]

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
        """Build an `EncoderDecoderModelInput` from a broadcasted dict,
        using this runner's attention backend."""
        return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInput:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        Since chunked prefill is not supported for encoder/decoder models,
        `input_tokens` is assumed to be either entirely prefill tokens or
        entirely decode tokens.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                       model_input))

        # Inject attn_metadata encoder/cross-attention fields &
        # encoder input tokens/positions into model_input.
        # Frozen dataclass fields cannot be modified, so use
        # dataclasses.replace to construct a new model input
        # instance.
        model_input = dataclasses.replace(
            model_input,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
        )

        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prefill inputs for profiling.

        Raises:
            NotImplementedError: if the multimodal registry reports a
                non-zero maximum multimodal token count for the model.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            raise NotImplementedError(
                "Multi-modal encoder-decoder models are not supported yet")

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, _ = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                encoder_seq_data=seq_data,
                cross_block_table=None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInput,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Helper method to prepare the encoder- and cross-attn-related
        model inputs based on a given sequence group. These additional inputs
        are used to augment an already-computed `EncoderDecoderModelInput`
        data structure which already has decoder-related model inputs
        populated.

        Sets the following attn_metadata fields (in place, on
        `model_input.attn_metadata`):
        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
                                   compute inputs
        * model_input: model inputs data structure with decoder-oriented
                       fields already computed.

        Return:

        * Tuple of (updated attn_metadata, encoder input tokens tensor,
          encoder input positions tensor). For an empty
          `seq_group_metadata_list` the attn_metadata is returned
          untouched and both tensors are None.
        """

        if not seq_group_metadata_list:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            encoder_input_tokens: List[int] = []
            encoder_input_positions: List[int] = []
            cross_slot_mapping: List[int] = []
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping
                is_profile_run = (seq_group_metadata.block_tables is None)
                if is_profile_run:
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
                else:
                    # Loop-invariant lookups hoisted out of the
                    # per-token loop.
                    cross_block_table = seq_group_metadata.cross_block_table
                    block_size = self.block_size
                    for i in range(seq_len):
                        block_number = cross_block_table[i // block_size]
                        block_offset = i % block_size
                        slot = block_number * block_size + block_offset
                        cross_slot_mapping.append(slot)

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(range(seq_len))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()

            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during Aphrodite memory profiling.
            cross_block_tables = []
            for seq_group_metadata in seq_group_metadata_list:
                encoder_seq_lens.append(
                    seq_group_metadata.encoder_seq_data.get_len())
                cross_block_table = seq_group_metadata.cross_block_table
                cross_block_tables.append([] if (
                    cross_block_table is None) else cross_block_table)

            if (model_input.attn_metadata is not None
                    and model_input.attn_metadata.use_cuda_graph):
                # We will be using CUDA graph replay for this decode.
                max_len_of_block_table = self.get_max_block_per_batch()
                batch_size = len(encoder_seq_lens)
                graph_batch_size = _get_graph_batch_size(batch_size)
                assert graph_batch_size >= batch_size
                cuda_graph_pad_size = graph_batch_size - batch_size
                # extend the cross_block_tables and encoder_seq_lens to match
                # the graph_batch_size.
                cross_block_tables.extend([[]
                                           for _ in range(cuda_graph_pad_size)
                                           ])
                encoder_seq_lens.extend(
                    itertools.repeat(1, cuda_graph_pad_size))
            else:
                max_len_of_block_table = max(
                    len(block_table) for block_table in cross_block_tables)

            cross_block_tables = make_tensor_with_pad(
                cross_block_tables,
                max_len=max_len_of_block_table,
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths.
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        (
            attn_metadata.num_encoder_tokens,
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
            sum(encoder_seq_lens),
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)