from typing import List

import pytest
import torch

from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
from aphrodite.common.utils import get_open_port
from aphrodite.distributed.parallel_state import (
    ensure_model_parallel_initialized, init_distributed_environment)
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.worker.model_runner import ModelRunner, _get_graph_batch_size


def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
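    """Build a driver ModelRunner for the given model from EngineArgs,
    wiring through the engine's model/parallel/scheduler/cache configs."""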
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    model_runner = ModelRunner(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        lora_config=engine_config.lora_config,
        prompt_adapter_config=engine_config.prompt_adapter_config,
        is_driver_worker=True,
    )
    return model_runner


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
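    """Prefill-only batches: check the prepared input tensors, attention
    metadata (seq lens, start locs, block tables), and sampling indices."""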
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData.from_seqs(range(seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)
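
    # For prefill, sampling happens at the last token of each prompt, so the
    # expected selected indices are the cumulative offsets of each prompt end.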
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    torch.testing.assert_close(attn_metadata.block_tables, expected)
    # CUDA graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    torch.testing.assert_close(input_tokens, input_positions)
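
    # Prepare sampling metadata; the selected token indices should point at
    # the last token of each prompt.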
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
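    """Decode-only batches with CUDA graphs enabled: inputs are padded to the
    captured graph batch size and the attention metadata reflects decode."""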
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    context_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # Assume each seq group finishes prefill.
    for i in range(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
        seq_data = SequenceData.from_seqs(range(context_len))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata, slot_mapping = (
        model_input.input_tokens, model_input.input_positions,
        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
    assert len(slot_mapping) == len(input_tokens)
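
    # With CUDA graphs, the batch is padded up to the nearest captured graph
    # batch size.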
    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    seq_lens = [context_len + 1 for context_len in context_lens]
    # seq_lens are padded to expected_bs
    for _ in range(expected_bs - len(seq_lens)):
        seq_lens.append(1)
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.num_decode_tokens == len(seq_lens)
    start_idx = 0
    start_loc = [start_idx]
    for _ in context_lens:
        # decode has only 1 token for query.
        start_idx += 1
        start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))
    # The block table's first dim corresponds to each entry in the batch,
    # i.e. one row per decode token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # The block table's second dim corresponds to each token's block number.
    # It is padded up to the maximum number of blocks per batch.
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    assert attn_metadata.use_cuda_graph is True
    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs

    # Verify Sampling
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for _ in context_lens:
        expected_selected_token_indices.append(selected_token_start_idx)
        selected_token_start_idx += 1
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1 for _ in range(len(context_lens))],
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


def test_empty_seq_group():
- """Verify prepare prompt and decode returns empty output."""
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
    )
    assert input_tokens is None
    assert input_positions is None
    assert attn_metadata is None

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
        model_input.seq_lens,
    )
    assert input_tokens is None
    assert input_positions is None
    assert attn_metadata is None
    assert return_seq_lens is None


@pytest.fixture
def distributed_init():
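    """Set up a single-process distributed environment (world size 1, tensor
    and pipeline parallel size 1) for tests that touch parallel state."""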
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
        local_rank=0)
    ensure_model_parallel_initialized(1, 1)


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
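    """Chunked prefill: a single batch mixes prefill and decode requests, and
    the prepared attention metadata splits consistently between the two."""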
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )

    # Add prefill requests.
    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    prefill_metadata_list: List[SequenceGroupMetadata] = []
    decode_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
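
    # Split the batch: the first half are prefill requests, the rest decodes.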
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData.from_seqs(range(seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)
        prefill_metadata_list.append(seq_group_metadata)

    # Add decode requests
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(context_len))
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
        decode_metadata_list.append(seq_group_metadata)
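
    # Prepare one hybrid batch that contains both the prefill and the decode
    # requests, then check the top-level prefill/decode token counts.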
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
    )

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata
    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)
    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above.
    attn_metadata = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list).attn_metadata

    for (name, expected_val), (_, actual_val) in zip(
            vars(attn_metadata.prefill_metadata).items(),
            vars(prefill_meta_actual).items()):
        if isinstance(expected_val, torch.Tensor):
            assert torch.equal(expected_val, actual_val), name
        else:
            assert expected_val == actual_val, name
    for (name, expected_val), (_, actual_val) in zip(
            vars(attn_metadata.decode_metadata).items(),
            vars(decode_meta_actual).items()):
        if isinstance(expected_val, torch.Tensor):
            assert torch.equal(expected_val, actual_val), name
        else:
            assert expected_val == actual_val, name