"""Round-trip tests for the broadcastable tensor-dict form of model inputs.

Each test builds a model input, converts it with
``as_broadcastable_tensor_dict()``, rebuilds it with
``from_broadcasted_tensor_dict()`` and checks which fields survive.
"""
import dataclasses
from typing import List, Tuple, Type

import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from aphrodite.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from aphrodite.worker.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):
    """Minimal backend stub handed to ``from_broadcasted_tensor_dict``.

    Only the metadata/builder class accessors return something; the remaining
    hooks either raise ``NotImplementedError`` or do nothing.
    """

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        # Return the class, as the annotation promises. Raising it would fail
        # with a TypeError because it is not an exception type.
        return AttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass


def _make_sampling_metadata() -> SamplingMetadata:
    """Build a SamplingMetadata filled with placeholder values."""
    return SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )


def _make_attn_metadata() -> AttentionMetadata:
    """Build the small AttentionMetadata fixture shared by every test."""
    return AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )


def _assert_common_fields_round_tripped(received, original,
                                        attn_metadata) -> None:
    """Check the fields every model-input flavour is expected to carry over.

    Args:
        received: Model input rebuilt from the broadcast tensor dict.
        original: Model input that was serialized.
        attn_metadata: Attention metadata the original was built with.
    """
    assert received.input_tokens is not None
    assert (received.input_tokens == original.input_tokens).all()
    assert received.input_positions is not None
    assert (received.input_positions == original.input_positions).all()
    assert received.multi_modal_kwargs is None
    assert received.multi_modal_kwargs == original.multi_modal_kwargs
    assert received.lora_requests is None
    assert received.lora_requests == original.lora_requests
    assert received.lora_mapping is None
    assert received.lora_mapping == original.lora_mapping
    # slot_mapping is a one-element tensor in these fixtures, so the
    # element-wise ``==`` result can be used directly as a boolean.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)


def test_model_runner_input():
    """Round-trip a ModelInputForGPUWithSamplingMetadata."""
    sampling_metadata = _make_sampling_metadata()
    attn_metadata = _make_attn_metadata()
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    _assert_common_fields_round_tripped(received_model_input, model_input,
                                        attn_metadata)

    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None


def test_embedding_model_runner_input():
    """Round-trip a ModelInputForGPUWithPoolingMetadata."""
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = _make_attn_metadata()
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))

    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithPoolingMetadata)
    _assert_common_fields_round_tripped(received_model_input, model_input,
                                        attn_metadata)

    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None


def test_multi_step_model_runner_input():
    """Round-trip a StatefulModelInput wrapping a frozen model input."""
    sampling_metadata = _make_sampling_metadata()
    attn_metadata = _make_attn_metadata()
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))

    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    # Compares the *received* frozen input against the original one,
    # including multi_modal_kwargs.
    _assert_common_fields_round_tripped(received_frozen_input,
                                        frozen_model_input, attn_metadata)

    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # check non frozen fields
    # NOTE(review): num_queries and num_seqs are set above but not asserted
    # here - confirm whether they are expected to be broadcast.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step