12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import random
- import torch
- from aphrodite.common.sequence import (SamplingParams, SequenceData,
- SequenceGroupMetadata)
- from aphrodite.task_handler.model_runner import ModelRunner
def test_prepare_prompt():
    """Check ModelRunner prompt preparation and sampled-token selection.

    Builds a randomly sized batch (1-256 groups) of single-sequence prompt
    groups whose token ids equal their positions (``0 .. prompt_len - 1``).
    It then verifies:

    * ``_prepare_prompt`` hands back the prompt lengths unchanged.
    * ``_prepare_prompt`` produces token and position tensors of shape
      ``(batch_size, max_seq_len)`` that are equal to each other.
    * ``_prepare_sample`` selects, for every row, the flat index of that
      row's last prompt token.
    """
    runner = ModelRunner(None, None, None)
    runner.set_block_size(16)
    batch_size = random.randint(1, 256)

    # Lengths cycle through 1 .. block_size - 1 so that every prompt fits
    # into the single block (block id 1) it is given below.
    prompt_lens = [
        idx % (runner.block_size - 1) + 1 for idx in range(batch_size)
    ]
    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData(list(range(plen)))},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
            persistent_data={},
        ) for idx, plen in enumerate(prompt_lens)
    ]

    max_seq_len = max(prompt_lens)
    # The test expects row ``row`` to occupy the flat slots starting at
    # ``row * max_seq_len``, with its last prompt token ``plen - 1`` further on.
    expected_selected_token_indices = [
        row * max_seq_len + plen - 1 for row, plen in enumerate(prompt_lens)
    ]

    input_tokens, input_positions, returned_prompt_lens = (
        runner._prepare_prompt(seq_group_metadata_list))
    assert returned_prompt_lens == prompt_lens
    sampling_metadata = runner._prepare_sample(seq_group_metadata_list,
                                               prompt_lens)

    expected_shape = (batch_size, max_seq_len)
    assert input_tokens.shape == expected_shape
    assert input_positions.shape == expected_shape
    # Token ids were chosen equal to their positions, so the tensors coincide.
    torch.testing.assert_close(input_tokens, input_positions)

    selected = sampling_metadata.selected_token_indices
    torch.testing.assert_close(
        selected,
        torch.tensor(expected_selected_token_indices,
                     device=selected.device,
                     dtype=selected.dtype))
|