import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.utils import set_random_seed
from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.task_handler.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
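    """Sampler whose logit computation is stubbed out.

    ``_prune_hidden_states`` and ``_get_logits`` are patched so the sampler
    operates directly on a caller-supplied ``fake_logits`` tensor, letting
    the tests control exactly which logits the sampling code sees.
    """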

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
                   lambda x, y: x), patch(
                       "aphrodite.modeling.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
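    """Create a random hidden-state tensor, uniform fake logits, a mock
    sampler, and a bare ModelRunner for a batch of ``batch_size``."""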
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None)
    return input_tensor, fake_logits, sampler, model_runner
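

# Repeat each test across many seeds so batch sizes and sampled values vary.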
RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
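    """Greedy sampling (temperature=0) must return the argmax of the logits."""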
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
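    """With one dominant logit per row, random sampling must pick that token."""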
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

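    # Boost token i in row i so that, despite temperature=1.0, sampling is
    # effectively deterministic.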
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
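    """Beam-search smoke test: only checks that the sampler does not raise."""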
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler(embedding=None,
            hidden_states=input_tensor,
            sampling_metadata=sampling_metadata)
    # No assertion here: the expected beam-search output is hard to compute
    # by hand, so this test only verifies that the sampler raises no
    # exceptions when every sequence group uses beam search.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
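    """A batch mixing greedy, random-sampling, and beam-search groups."""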
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    prompt_lens = []
    for i in range(batch_size):
        n = 1
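        # Pick a sampling strategy per group: 0 = greedy, 1 = random
        # sampling with randomized parameters, 2 = beam search.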
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
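            # Beam-search outputs are not verified here; see
            # test_sampler_all_beam.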
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
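    """Per-request logits processors must be applied before sampling."""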
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This logits processor gives an infinite score to the i-th token,
    # where i is the number of tokens generated so far, so the expected
    # output token sequence is [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx