import itertools from typing import List import pytest import torch from aphrodite.common.sequence import (SamplingParams, SequenceData, SequenceGroupMetadata) from aphrodite.common.utils import is_cpu, make_tensor_with_pad from aphrodite.engine.args_tools import EngineArgs from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner from aphrodite.worker.model_runner import _get_graph_batch_size BATCH_SIZES = [1, 4, 16, 64, 256] def _create_model_runner(model: str, *args, **kwargs) -> EncoderDecoderModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( model_config=engine_config.model_config, parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, load_config=engine_config.load_config, lora_config=engine_config.lora_config, prompt_adapter_config=engine_config.prompt_adapter_config, is_driver_worker=True, ) return model_runner @pytest.mark.skipif(condition=is_cpu(), reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output for empty seq group list""" model_runner = _create_model_runner( "facebook/bart-base", seed=0, dtype="float16", max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, enforce_eager=True, ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) ( input_tokens, input_positions, encoder_input_tokens, encoder_input_positions, attn_metadata, return_seq_lens, ) = ( model_input.input_tokens, model_input.input_positions, model_input.encoder_input_tokens, model_input.encoder_input_positions, model_input.attn_metadata, model_input.seq_lens, ) assert input_tokens is None assert input_positions is None assert encoder_input_tokens is None assert encoder_input_positions is None assert attn_metadata is None assert return_seq_lens is None @pytest.mark.skipif(condition=is_cpu(), reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) def test_prepare_prompt(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce prefill-phase model inputs & attention metadata. Test behavior: * Instantiate BART base model & enc/dec model runner * Construct sequence-group metadata for dummy prompts * Test that encoder attention, decoder self-attention, and encoder/decoder cross-attention inputs are correct Arguments: * batch_size * backend_name: The attention backend under test * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) ''' model_runner = _create_model_runner( "facebook/bart-base", seed=0, dtype="float16", max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, enforce_eager=True, ) seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] block_tables = {0: [1]} cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, seq_data={0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, ) assert seq_group_metadata.token_chunk_size == seq_data.get_len() seq_group_metadata_list.append(seq_group_metadata) # Build # * Decoder model inputs # * Decoder self-attention KV caching data structures # * Encoder model inputs # * Encoder/decoder cross-attention KV caching data structures model_input = model_runner.prepare_model_input(seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping encoder_input_tokens = model_input.encoder_input_tokens encoder_input_positions = model_input.encoder_input_positions cross_slot_mapping = attn_metadata.cross_slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) assert len(cross_slot_mapping) == len(encoder_input_tokens) # Verify input metadata is correct for prompts. # - Decoder attention metadata device = model_runner.device assert attn_metadata.num_prefills > 0 assert attn_metadata.num_decode_tokens == 0 assert torch.equal(attn_metadata.seq_lens_tensor, torch.tensor(seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == seq_lens assert attn_metadata.max_prefill_seq_len == max(seq_lens) assert attn_metadata.max_decode_seq_len == 0 # - Encoder attention metadata assert attn_metadata.encoder_seq_lens == encoder_seq_lens assert torch.equal( attn_metadata.encoder_seq_lens_tensor, torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) # Test decoder subquery start locs. start_idx = 0 start_loc = [start_idx] for seq_len in seq_lens: start_idx += seq_len start_loc.append(start_idx) assert torch.equal( attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device), ) # Test decoder seq start locs & context lengths assert torch.equal( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device), ) assert torch.equal( attn_metadata.context_lens_tensor, torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device), ) # Verify block tables are correct for prompts # - Decoder self-attention expected = torch.tensor( [[] for _ in range(len(seq_group_metadata_list))], dtype=torch.int32, device=model_runner.device, ) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention assert torch.equal( attn_metadata.cross_block_tables, expected, ) # Cuda graph should not be used for prefill. assert attn_metadata.use_cuda_graph is False # Verify the lengths of input tokens & positions # - Decoder assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) # -- An indirect check that model_input.input_tokens # and model_input.input_positions are correct - # by design of the test, the input tokens are # equal to the input position values, so if # the model_input data structure has the correct # values then these two should be equal assert torch.equal( input_tokens, input_positions, ) # - Encoder assert len(encoder_input_tokens) == sum(encoder_seq_lens) # -- An indirect check that model_input.encoder_input_tokens # and model_input.encoder_input_positions are correct - # by design of the test, the input tokens are # equal to the input position values, so if # the model_input data structure has the correct # values then these two should be equal assert torch.equal( encoder_input_tokens, encoder_input_positions, ) # Test that Aphrodite sampling infrastructure chooses the correct # sequence positions at which to sample (i.e. the end of # each sequence) in the prefill phase expected_selected_token_indices = [] selected_token_start_idx = 0 for seq_len in seq_lens: # Compute the index offset of the final token in each # prompt (recall that the prompts are concatenated) expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len sampling_metadata = model_input.sampling_metadata actual = sampling_metadata.selected_token_indices expected = torch.tensor( expected_selected_token_indices, device=actual.device, dtype=actual.dtype, ) assert torch.equal(actual, expected) @pytest.mark.skipif(condition=is_cpu(), reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. Test behavior: * Instantiate BART base model & enc/dec model runner * Construct sequence-group metadata for dummy prompts * Test that encoder attention, decoder self-attention, and encoder/decoder cross-attention inputs are correct Arguments: * batch_size * multiple_seqs_per_seq_group * backend_name: The attention backend under test * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) ''' model_runner = _create_model_runner( "facebook/bart-base", seed=0, dtype="float16", max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, enforce_eager=True, ) seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] block_tables = { 0: [1], 1: [3] } if multiple_seqs_per_seq_group else { 0: [1] } cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, seq_data={ 0: seq_data, 1: seq_data } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, ) assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) seq_lens.extend( [seq_len for _ in range(len(seq_group_metadata.seq_data))]) encoder_seq_lens.extend( [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) # Build # * Decoder model inputs # * Decoder self-attention KV caching data structures # * Encoder model inputs # * Encoder/decoder cross-attention KV caching data structures model_input = model_runner.prepare_model_input(seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping encoder_input_tokens = model_input.encoder_input_tokens encoder_input_positions = model_input.encoder_input_positions cross_slot_mapping = attn_metadata.cross_slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) assert len(cross_slot_mapping) == len(encoder_input_tokens) # Verify input metadata is correct for decode phase. # - Decoder attention metadata device = model_runner.device assert attn_metadata.num_prefills == 0 assert attn_metadata.num_decode_tokens > 0 assert torch.equal(attn_metadata.seq_lens_tensor, torch.tensor(seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == seq_lens assert attn_metadata.max_prefill_seq_len == 0 assert attn_metadata.max_decode_seq_len == max(seq_lens) # - Encoder attention metadata assert attn_metadata.encoder_seq_lens == encoder_seq_lens assert torch.equal( attn_metadata.encoder_seq_lens_tensor, torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) # Test decoder subquery start locs. start_idx = 0 start_loc = [start_idx] for seq_len in seq_lens: start_idx += 1 start_loc.append(start_idx) assert torch.equal( attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device), ) # Test decoder seq start locs. Note that for normal prefill it is # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] for seq_len in seq_lens: start_idx += seq_len seq_start_loc.append(start_idx) # Test seq_start_loc and context lengths assert torch.equal( attn_metadata.seq_start_loc, torch.tensor(seq_start_loc, dtype=torch.int32, device=device), ) assert torch.equal( attn_metadata.context_lens_tensor, torch.tensor([seq_len - 1 for seq_len in seq_lens], dtype=torch.int, device=device)) # Verify block tables are correct for prompts # - Decoder self-attention flattened_block_tables = [ block_table for block_table in block_tables.values() ] expected = torch.tensor(flattened_block_tables * len(seq_group_metadata_list), dtype=torch.int32, device=model_runner.device) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention expected = torch.tensor([ cross_block_table for seq_group_metadata in seq_group_metadata_list for _ in range(len(seq_group_metadata.seq_data)) ], dtype=torch.int32, device=model_runner.device) assert torch.equal( attn_metadata.cross_block_tables, expected, ) # Model runner's CUDAGraph setting should be propagated to attention # metadata. assert attn_metadata.use_cuda_graph is False # Verify the lengths of input tokens & positions # - Decoder assert len(input_tokens) == len(seq_lens) assert len(input_positions) == len(seq_lens) # -- An indirect check that model_input.input_tokens # and model_input.input_positions are correct - # by design of the test, the input tokens are # equal to the input position values, so if # the model_input data structure has the correct # values then these two should be equal assert torch.equal( input_tokens, input_positions, ) # - Encoder assert len(encoder_input_tokens) == 0 assert len(encoder_input_tokens) == 0 # -- An indirect check that model_input.encoder_input_tokens # and model_input.encoder_input_positions are correct - # by design of the test, the input tokens are # equal to the input position values, so if # the model_input data structure has the correct # values then these two should be equal assert torch.equal( encoder_input_tokens, encoder_input_positions, ) # Test that Aphrodite sampling infrastructure chooses the correct # sequence positions at which to sample (i.e. the end of # each sequence) in the decode phase expected_selected_token_indices = [] for selected_token_start_idx, seq_len in enumerate(seq_lens): # Compute the index offset of the final token in each # sequence's decoded outputs; since a single token is # decoded per iteration per sequence, then the length # of the decoded tokens for a given sequence is 1 and # the final index offset into a given sequence's # generated tokens is 0 (i.e. the expected sampling index # for a given sequence is just `selected_token_start_idx`) expected_selected_token_indices.append(selected_token_start_idx) sampling_metadata = model_input.sampling_metadata actual = sampling_metadata.selected_token_indices expected = torch.tensor( expected_selected_token_indices, device=actual.device, dtype=actual.dtype, ) assert torch.equal(actual, expected) @pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): """ Tests that for encoder-decoder models with CUDA Graph capture and replay enabled, the tensors used during the decode phase are correctly padded for varying input batch sizes. """ model_runner = _create_model_runner( "facebook/bart-base", seed=0, dtype="float16", max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, enforce_eager=False, ) block_tables = { 0: [1], 1: [3] } if multiple_seqs_per_seq_group else { 0: [1] } seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] cross_block_table = [2] expanded_batch_size = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, seq_data={ 0: seq_data, 1: seq_data } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, ) assert seq_group_metadata.token_chunk_size == 1 seq_lens.extend( [seq_len for _ in range(len(seq_group_metadata.seq_data))]) encoder_seq_lens.extend( [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) expanded_batch_size = expanded_batch_size + len( seq_group_metadata.seq_data) seq_group_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping encoder_input_tokens = model_input.encoder_input_tokens encoder_input_positions = model_input.encoder_input_positions cross_slot_mapping = attn_metadata.cross_slot_mapping # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. graph_batch_size = _get_graph_batch_size(expanded_batch_size) cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( itertools.repeat(1, cuda_graph_pad_size)) assert return_seq_lens == padded_seq_lens assert len(slot_mapping) == len(input_tokens) assert len(cross_slot_mapping) == len(encoder_input_tokens) # Verify attention metadata device = model_runner.device assert attn_metadata.num_prefills == 0 assert attn_metadata.num_decode_tokens > 0 assert torch.equal( attn_metadata.seq_lens_tensor, torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == padded_seq_lens assert attn_metadata.max_prefill_seq_len == 0 assert attn_metadata.max_decode_seq_len == max(seq_lens) # - Encoder attention metadata assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens assert torch.equal( attn_metadata.encoder_seq_lens_tensor, torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) # Verify block tables are correct for prompts # - Decoder self-attention. Pad the block tables as expected. flattened_block_tables = [ block_table for _ in range(len(seq_group_metadata_list)) for block_table in block_tables.values() ] flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( flattened_block_tables, max_len=64, pad=0, dtype=torch.int32, device=model_runner.device, ) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention. Pad the cross-attention block tables # as expected. expected = [ cross_block_table for seq_group_metadata in seq_group_metadata_list for _ in range(len(seq_group_metadata.seq_data)) ] expected.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( expected, max_len=64, pad=0, dtype=torch.int32, device=model_runner.device, ) assert torch.equal( attn_metadata.cross_block_tables, expected, ) # Model runner's CUDAGraph setting should be propagated to attention # metadata. assert attn_metadata.use_cuda_graph is True # Verify the lengths of input tokens & positions # - Decoder assert len(input_tokens) == len(padded_seq_lens) assert len(input_positions) == len(padded_seq_lens) # -- An indirect check that model_input.input_tokens # and model_input.input_positions are correct - # by design of the test, the input tokens are # equal to the input position values, so if # the model_input data structure has the correct # values then these two should be equal assert torch.equal( input_tokens, input_positions, ) # - Encoder assert len(encoder_input_tokens) == 0 assert len(encoder_input_tokens) == 0 # -- An indirect check that model_input.encoder_input_tokens # and model_input.encoder_input_positions are correct - # by design of the test, the input tokens are # equal to the input position values, so if # the model_input data structure has the correct # values then these two should be equal assert torch.equal( encoder_input_tokens, encoder_input_positions, )