123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648 |
- import itertools
- from array import array
- from typing import List
- import pytest
- import torch
- from aphrodite.common.sequence import (SamplingParams, SequenceData,
- SequenceGroupMetadata)
- from aphrodite.common.utils import is_cpu, make_tensor_with_pad
- from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
- from aphrodite.engine.args_tools import EngineArgs
- from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner
- from aphrodite.worker.model_runner import _get_graph_batch_size
# Batch sizes used to parametrize the eager-mode prefill and decode tests:
# successive powers of four, from a single sequence group up to 256.
BATCH_SIZES = [4**exponent for exponent in range(5)]
def _create_model_runner(
    model: str, *args, **kwargs
) -> EncoderDecoderModelRunner:
    """Construct an encoder/decoder model runner for ``model``.

    All extra positional and keyword arguments are forwarded unchanged to
    ``EngineArgs``; the engine config it produces supplies every sub-config
    handed to the runner. The runner is always created as the driver worker.
    """
    config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return EncoderDecoderModelRunner(
        model_config=config.model_config,
        parallel_config=config.parallel_config,
        scheduler_config=config.scheduler_config,
        device_config=config.device_config,
        cache_config=config.cache_config,
        load_config=config.load_config,
        lora_config=config.lora_config,
        prompt_adapter_config=config.prompt_adapter_config,
        is_driver_worker=True,
    )
@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/ "
    "decoder models",
)
def test_empty_seq_group():
    """Preparing model inputs for an empty list of sequence groups must
    leave every decoder/encoder input field, the attention metadata and the
    sequence lengths set to ``None``."""
    runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    empty_batch: List[SequenceGroupMetadata] = []
    model_input = runner._prepare_model_input_tensors(empty_batch)

    fields_expected_none = (
        "input_tokens",
        "input_positions",
        "encoder_input_tokens",
        "encoder_input_positions",
        "attn_metadata",
        "seq_lens",
    )
    for field_name in fields_expected_none:
        assert getattr(model_input, field_name) is None
@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/ "
    "decoder models",
)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_prompt(batch_size):
    """
    Test the ability of the encoder/decoder model runner subclass to
    produce prefill-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner in eager mode
      (``enforce_eager=True``, i.e. no CUDAGraph)
    * Construct sequence-group metadata for ``batch_size`` dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of single-sequence groups in the prefill batch
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # Keep every decoder and encoder sequence shorter than one block so
        # all of its tokens fit into a single KV-cache block.
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        # Token ids are 0..seq_len-1, i.e. equal to their positions; the
        # token/position equality checks below rely on this.
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(seq_len))
        )
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))
        )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        # The whole prompt is expected to be processed as a single chunk.
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for prompts.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

    # Test decoder subquery start locs. In prefill every prompt token is a
    # query token, so the offsets are the running sum of the prompt lengths.
    start_loc = [0, *itertools.accumulate(seq_lens)]
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    # Test decoder seq start locs (identical to the query start locs in
    # prefill) & context lengths (no prior context for a fresh prompt).
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.zeros(
            attn_metadata.context_lens_tensor.shape[0],
            dtype=torch.int,
            device=device,
        ),
    )

    # Verify block tables are correct for prompts: the prefill block tables
    # are expected to be empty, i.e. one zero-width row per sequence group.
    # - Decoder self-attention
    expected = torch.tensor(
        [[] for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == sum(encoder_seq_lens)
    assert len(encoder_input_positions) == sum(encoder_seq_lens)
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that Aphrodite sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the prefill phase. The prompts are concatenated,
    # so the final token of each prompt sits at (running sum of lengths) - 1.
    expected_selected_token_indices = [
        prompt_end - 1 for prompt_end in itertools.accumulate(seq_lens)
    ]
    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)
@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/ "
    "decoder models",
)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_decode(batch_size):
    """
    Test the ability of the encoder/decoder model runner subclass to
    produce decode-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner in eager mode
      (``enforce_eager=True``, i.e. no CUDAGraph)
    * Construct decode-phase sequence-group metadata for ``batch_size``
      dummy sequences
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of single-sequence groups in the decode batch
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # Keep every decoder and encoder sequence shorter than one block so
        # all of its tokens fit into a single KV-cache block.
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        # Token ids are 0..seq_len-1, i.e. equal to their positions; the
        # token/position equality checks below rely on this.
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(seq_len))
        )
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))
        )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        # A decode step processes exactly one new token per sequence.
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for decode phase.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

    # Test decoder subquery start locs: one query token per sequence, so the
    # offsets are simply 0, 1, ..., batch_size.
    query_start_loc = list(range(len(seq_lens) + 1))
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(query_start_loc, dtype=torch.int32, device=device),
    )
    # Test decoder seq start locs. Unlike prefill (where they coincide with
    # query_start_loc), in decode each sequence contributes its full length,
    # so the offsets are the running sum of seq_lens.
    seq_start_loc = [0, *itertools.accumulate(seq_lens)]
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
    )
    # Every token except the one being decoded is context.
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.tensor(
            [seq_len - 1 for seq_len in seq_lens],
            dtype=torch.int,
            device=device,
        ),
    )

    # Verify block tables are correct for the decode phase
    # - Decoder self-attention
    expected = torch.tensor(
        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    expected = torch.tensor(
        [cross_block_table for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder: one input token per sequence
    assert len(input_tokens) == len(seq_lens)
    assert len(input_positions) == len(seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder: no encoder inputs are expected during decode
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that Aphrodite sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the decode phase. Each sequence decodes exactly one
    # token per step, so sequence k is sampled at flattened index k.
    expected_selected_token_indices = list(range(len(seq_lens)))
    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)
@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/ "
    "decoder models",
)
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """
    Tests that for encoder-decoder models with CUDA Graph capture and replay
    enabled, the tensors used during the decode phase are correctly padded
    for varying input batch sizes.

    The batch is padded up to ``_get_graph_batch_size(batch_size)``; padding
    entries are expected to have decoder and encoder sequence length 1 and
    empty (zero-filled) block tables.

    Arguments:

    * batch_size: number of real (unpadded) single-sequence groups
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=False,
    )

    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # Keep every decoder and encoder sequence shorter than one block so
        # all of its tokens fit into a single KV-cache block.
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        # Token ids are 0..seq_len-1, i.e. equal to their positions; the
        # token/position equality checks below rely on this.
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(seq_len))
        )
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))
        )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        # A decode step processes exactly one new token per sequence.
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping

    # With CUDA Graph capture and replay enabled, the decoder and encoder
    # input sequences will be padded. Create the expected padded tensors
    # accordingly.
    graph_batch_size = _get_graph_batch_size(batch_size)
    cuda_graph_pad_size = graph_batch_size - batch_size
    padded_seq_lens = seq_lens + [1] * cuda_graph_pad_size
    padded_encoder_seq_lens = encoder_seq_lens + [1] * cuda_graph_pad_size

    assert return_seq_lens == padded_seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(padded_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.seq_lens == padded_seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    # Padding entries have length 1, so they never raise the maximum.
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)

    # Verify block tables are correct for the decode phase.
    # NOTE(review): max_len=64 presumably matches the per-sequence block-table
    # width the runner allocates for CUDA graphs with bart-base (max model
    # length / block size) - TODO confirm, ideally derive it from the runner.
    # - Decoder self-attention. Pad the block tables as expected.
    expected = [block_tables[0] for _ in range(batch_size)]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
    # as expected.
    expected = [cross_block_table for _ in range(batch_size)]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is True

    # Verify the lengths of input tokens & positions
    # - Decoder: one input token per (padded) sequence
    assert len(input_tokens) == len(padded_seq_lens)
    assert len(input_positions) == len(padded_seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder: no encoder inputs are expected during decode
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )
|