12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876 |
- import itertools
- import random
- from array import array
- from dataclasses import dataclass
- from typing import Dict, List, Optional, Tuple
- from unittest.mock import Mock, patch
- import pytest
- import torch
- from transformers import GenerationConfig, GenerationMixin
- from aphrodite.common.sequence import (SamplingParams, SequenceData,
- SequenceGroupMetadata)
- from aphrodite.common.utils import Counter, is_pin_memory_available
- from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
- from aphrodite.modeling.layers.sampler import Sampler
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.modeling.utils import set_random_seed
class MockLogitsSampler(Sampler):
    """Sampler subclass used by the tests in this module.

    The instance keeps a reference to a pre-built logits tensor; sampling
    itself is left entirely to ``Sampler``.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        # Kept on the instance for reference; ``forward`` does not read it.
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        """Forward all arguments, unchanged, to ``Sampler.forward``."""
        result = super().forward(*args, **kwargs)
        return result
def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Build the tensors and sampler shared by the tests.

    Returns a tuple of:
      * a random float16 tensor of shape ``(batch_size, 1024)``,
      * a ``(batch_size, VOCAB_SIZE)`` tensor filled with ``1e-2`` in the
        same dtype,
      * a ``MockLogitsSampler`` constructed from that second tensor.
    """
    hidden = torch.rand((batch_size, 1024), dtype=torch.float16)
    uniform_logits = torch.full(
        (batch_size, VOCAB_SIZE),
        1e-2,
        dtype=hidden.dtype,
    )
    return hidden, uniform_logits, MockLogitsSampler(uniform_logits)
# Width of the fake logits tensors built by `_prepare_test`.
VOCAB_SIZE = 32000
# Seeds every seeded test is parametrized over.
RANDOM_SEEDS = [*range(128)]
# One device when exactly one GPU is reported; two entries in every other
# case (including when no GPU is reported).
CUDA_DEVICES = [
    "cuda:" + str(index)
    for index in range(1 if torch.cuda.device_count() == 1 else 2)
]
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` over ``batch_size`` identical prompt sequence groups.

    Each group is a prompt holding the token ids ``[1, 2, 3]`` under
    sequence id 0 and shares the same ``sampling_params``. ``input_tensor``
    is passed to the sampler as its ``logits`` argument.

    Returns whatever the sampler call returns.
    """
    groups: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={
                0: SequenceData(
                    array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
            },
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for row in range(batch_size)
    ]
    prompt_lens: List[int] = [
        group.seq_data[0].get_len() for group in groups
    ]
    metadata = SamplingMetadata.prepare(
        groups,
        prompt_lens,
        query_lens=prompt_lens,
        device=device,
        pin_memory=is_pin_memory_available(),
    )
    return sampler(logits=input_tensor, sampling_metadata=metadata)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """With temperature 0 every sample must equal the row-wise argmax."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    greedy_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                greedy_params, device)

    # The argmax is taken after sampling, from the same tensor that was
    # handed to the sampler.
    argmax_per_row = torch.argmax(fake_logits, dim=-1)
    for row, sequence_output in enumerate(sampler_output):
        for sample in sequence_output.samples:
            assert sample.output_token == argmax_per_row[row].item()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Random sampling must pick the single dominant token of each row.

    Row ``i`` has logit ``1e2`` at column ``i`` and ``1e-2`` elsewhere, so
    every one of the ``n`` samples for row ``i`` is expected to be ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for row in range(batch_size):
        fake_logits[row, row] = 1e2

    num_samples = random.randint(1, 10)
    sampling_params = SamplingParams(temperature=1.0, n=num_samples)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for row, sequence_output in enumerate(sampler_output):
        sampled = [s.output_token for s in sequence_output.samples]
        assert all(token == row for token in sampled)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Same as the unseeded random test, but with a per-request seed set.

    Row ``i`` has logit ``1e2`` at column ``i``; every sample for that row
    is expected to be ``i``.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for row in range(batch_size):
        fake_logits[row, row] = 1e2

    # Draw order matters for reproducibility: `n` first, then the seed.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(
        temperature=1.0,
        n=num_samples,
        seed=request_seed,
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for row, sequence_output in enumerate(sampler_output):
        sampled = [s.output_token for s in sequence_output.samples]
        assert all(token == row for token in sampled)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two seeded sampling passes over the same inputs must compare equal."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Draw order matters for reproducibility: `n` first, then the seed.
    num_samples = random.randint(1, 10)
    request_seed = random.randint(0, 10000)
    sampling_params = SamplingParams(
        temperature=1.0,
        n=num_samples,
        seed=request_seed,
    )

    outputs = [
        _do_sample(batch_size, fake_logits, sampler, sampling_params,
                   device) for _ in range(2)
    ]
    first_sampler_output, second_sampler_output = outputs
    assert first_sampler_output == second_sampler_output
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Smoke test: an all-beam-search batch must sample without raising.

    Nothing is asserted about the outputs because there is no obvious way
    to decide which outputs are expected; the test only checks that the
    sampler handles the all-beam-search case without an exception.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    beam_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, beam_params, device)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Check which logits rows get their stop tokens set to ``-inf``.

    A test case is a dict with two entries:
      * ``seq_group_metadata_list``: the sequence groups to sample from,
      * ``expected_penalization``: one bool per logits row saying whether
        the stop tokens of that row are expected to be ``-inf`` after the
        sampler ran.

    Seed 0 runs the hand-written edge cases defined below; every other seed
    runs one randomly generated case.
    """
    # Seed first so that the starting sequence id (and therefore the whole
    # test case) is reproducible for a given `seed`.
    set_random_seed(seed)
    torch.set_default_device(device)
    seq_id_counter = Counter(start=random.randint(0, 100))

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        """Build SamplingParams and register `eos_token_id` as stop token."""
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        """Build SequenceData with random prompt and output token ids."""
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        """Generate one random test case (all prompts or all decodes)."""
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # a row is penalized while fewer than `min_tokens` tokens
                # have been generated
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Sample once and compare the -inf pattern of the logits rows."""
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a
                    # row in logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(expected_penalization) == batch_size, \
            ("Invalid test case, expected_penalization does not match "
             "computed batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    # The message is parenthesised so both halves belong to
                    # the assert; as two bare lines the second literal is a
                    # separate no-op statement.
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'
                    ), (f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch mixing greedy, random, seeded-random and beam rows.

    The batch is sampled once, shuffled, and sampled again. Greedy rows
    must return the argmax, unseeded random rows must return one of the
    boosted tokens, and seeded rows must return the same tokens in the
    second pass as in the first. Beam-search rows are not checked.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    groups: List[SequenceGroupMetadata] = []
    expected_tokens: List[Optional[List[int]]] = []
    prompt_lens: List[int] = []
    for row in range(batch_size):
        row_expected: Optional[List[int]] = None
        kind = random.randint(0, 3)
        if kind == 0:
            # greedy
            params = SamplingParams(temperature=0)
            row_expected = [
                int(torch.argmax(fake_logits[row], dim=-1).item())
            ]
        elif kind in (1, 2):
            # random; the draws below must stay in this exact order
            num_samples = random.randint(1, 10)
            temperature = random.random() + 0.1
            top_p = min(random.random() + 0.1, 1)
            top_k = random.randint(0, 10) or -1
            presence_penalty = random.randint(0, 1)
            params = SamplingParams(
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                n=num_samples,
                presence_penalty=presence_penalty,
            )
            if kind == 2:
                # seeded: expectation is recorded during the first pass
                params.seed = random.randint(0, 10000)
            else:
                # unseeded: boost `num_samples` consecutive tokens
                for offset in range(num_samples):
                    fake_logits[row, row + offset] = 1e2
                row_expected = list(range(row, row + num_samples))
        else:
            # beam search
            params = SamplingParams(temperature=0,
                                    use_beam_search=True,
                                    best_of=2)
        expected_tokens.append(row_expected)

        group = SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={
                0: SequenceData(
                    array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
            },
            sampling_params=params,
            block_tables={0: [1]},
        )
        groups.append(group)
        prompt_lens.append(group.seq_data[0].get_len())

    generators: Dict[str, torch.Generator] = {}

    def sample_and_check():
        """Sample the current batch and verify (or record) each row."""
        sampling_metadata = SamplingMetadata.prepare(
            groups,
            prompt_lens,
            query_lens=prompt_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for pos, (sequence_output, group) in enumerate(
                zip(sampler_output, groups)):
            row_params = group.sampling_params
            if row_params.use_beam_search:
                continue

            seeded = row_params.seed is not None
            if seeded and expected_tokens[pos] is None:
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[pos] = [
                    sample.output_token
                    for sample in sequence_output.samples
                ]
                continue

            wanted = expected_tokens[pos]
            assert wanted is not None

            # Greedy and seeded rows must match exactly; unseeded random
            # rows only need to land on one of the high-logit tokens.
            exact = row_params.temperature == 0 or seeded
            for nth, sample in enumerate(sequence_output.samples):
                if exact:
                    assert sample.output_token == wanted[nth]
                else:
                    assert sample.output_token in wanted

    # Test batch
    sample_and_check()

    # Shuffle the batch and resample. Every list is shuffled in place with
    # a fresh Random(seed), so all of them receive the same permutation.
    order = list(range(batch_size))
    for to_shuffle in (order, groups, expected_tokens, prompt_lens):
        random.Random(seed).shuffle(to_shuffle)
    order_tensor = torch.tensor(order)
    input_tensor.data = input_tensor.index_select(0, order_tensor)
    fake_logits.data = fake_logits.index_select(0, order_tensor)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    sample_and_check()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare the sampler's top-k/top-p probabilities with HF's.

    The probabilities handed to the (patched) ``_sample`` function are
    captured and compared against the softmax of the logits produced by
    the Hugging Face ``GenerationMixin`` logits processors.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000

    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    # Reference implementation: Hugging Face logits processors.
    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)

    @dataclass
    class MockConfig:
        is_encoder_decoder: bool = False

    generation_model.config = MockConfig()  # needed by the following method
    generation_model._prepare_special_tokens(generation_config, device=device)
    processors = generation_model._get_logits_processor(generation_config,
                                                        None,
                                                        None,
                                                        None, [],
                                                        device=device)
    assert len(processors) == 2  # top_p and top_k

    groups: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{row}",
            is_prompt=True,
            seq_data={
                0: SequenceData(
                    array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
            },
            sampling_params=SamplingParams(
                temperature=1,
                top_k=top_k,
                top_p=top_p,
            ),
            block_tables={0: [1]},
        ) for row in range(batch_size)
    ]
    prompt_lens: List[int] = [
        group.seq_data[0].get_len() for group in groups
    ]

    sampling_metadata = SamplingMetadata.prepare(
        groups,
        prompt_lens,
        query_lens=prompt_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    captured = {"probs": None}

    def mock_sample(probs, *args, **kwargs):
        # Remember the probabilities and return the top-1 token per row.
        captured["probs"] = probs
        picks = [[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs]
        return (picks, None)

    with patch("aphrodite.modeling.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    sample_probs = captured["probs"]
    assert sample_probs is not None

    hf_logits = processors(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_logits, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    # The sets of zeroed-out tokens must agree exactly.
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """A request's token must not depend on its position in the batch.

    Two requests (greedy with repetition penalty, and seeded top-1
    sampling without it) are run in both orders; each request must yield
    the same token in either position.
    """
    vocab_size = 8

    def sample_pair(sampling_params: List[SamplingParams]):
        """Sample a two-row batch; return the first token of each row."""
        groups: List[SequenceGroupMetadata] = [
            SequenceGroupMetadata(
                request_id=f"test_{row}",
                is_prompt=True,
                seq_data={
                    0:
                    SequenceData(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
                },
                sampling_params=sampling_params[row],
                block_tables={0: [1]},
            ) for row in range(2)
        ]
        prompt_lens: List[int] = [
            group.seq_data[0].get_len() for group in groups
        ]

        sampling_metadata = SamplingMetadata.prepare(
            groups,
            prompt_lens,
            query_lens=prompt_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)
        # Token 1 (present in the prompt) gets the highest logit, token 5
        # the second highest.
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        return [output.samples[0].output_token for output in sampler_output]

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    rep_then_sample = sample_pair(
        [sampling_params_rep, sampling_params_sample])
    sample_then_rep = sample_pair(
        [sampling_params_sample, sampling_params_rep])

    assert rep_then_sample[0] == sample_then_rep[1]
    assert rep_then_sample[1] == sample_then_rep[0]
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_no_repeat_ngram(seed: int, device: str):
    """Test that no-repeat-ngram sampling behaves as expected."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Format: sequence: [tokens_that_should_be_blocked]
    #
    # NOTE(review): `input_seq` is only used for its length below; it is
    # not among the arguments passed to `_do_sample`, so these sequences
    # are not what the sampler receives. Confirm whether they were meant
    # to be fed in as the prompt.
    test_sequences = {
        (1, 2, 3): [3],  # With ngram_size=2, should block 3 after [2]
        (4, 5, 4, 5): [4],  # With ngram_size=2, should block 4 after [5]
        (6, 7, 8, 6, 7): [8],  # With ngram_size=3, should block 8 after [6, 7]
        (1, 2, 3, 4, 1, 2): [3],  # With ngram_size=4, should block 3 after [1, 2] # noqa: E501
    }

    def sample_first_row(params):
        """Sample a single-row batch built from the first logits row."""
        return _do_sample(
            1,
            fake_logits[0:1].clone(),  # Just use first row
            sampler,
            params,
            device,
        )

    for input_seq, blocked_tokens in test_sequences.items():
        for ngram_size in [2, 3, 4]:
            sampler_output = sample_first_row(
                SamplingParams(
                    temperature=1.0,
                    no_repeat_ngram_size=ngram_size,
                    seed=random.randint(0, 10000),
                ))

            if len(input_seq) < ngram_size:
                continue
            # the sampled token must not be one of the blocked tokens
            sampled = sampler_output[0].samples[0]
            for token in blocked_tokens:
                assert sampled.output_token != token, \
                    f"Token {token} should have been blocked by {ngram_size}-gram repetition prevention"  # noqa: E501

    # disabled
    sampler_output = sample_first_row(
        SamplingParams(
            temperature=1.0,
            no_repeat_ngram_size=0,
            seed=random.randint(0, 10000),
        ))
    output_token = sampler_output[0].samples[0].output_token
    assert output_token is not None, "Should produce output token with ngram_size=0"  # noqa: E501

    # determinism
    sampling_params = SamplingParams(
        temperature=1.0,
        no_repeat_ngram_size=3,
        seed=random.randint(0, 10000),
    )
    runs = [
        _do_sample(batch_size, fake_logits.clone(), sampler, sampling_params,
                   device) for _ in range(2)
    ]
    first_output, second_output = runs
    assert first_output == second_output, \
        "No-repeat-ngram sampling is not deterministic with same seed"
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_nsigma(seed: int, device: str):
    """Test that top-nsigma sampling behaves as expected."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Create a clear separation in logits for testing: each row gets one
    # to five columns raised to 10.0, remembered for verification.
    vocab = fake_logits.size(1)
    high_logit_indices = {}
    for row in range(batch_size):
        how_many = random.randint(1, 5)
        chosen = random.sample(range(vocab), how_many)
        high_logit_indices[row] = set(chosen)
        for col in chosen:
            fake_logits[row, col] = 10.0  # Clearly above the mean

    # Test with different nsigma values
    for nsigma in [1.5, 2.0, 3.0]:
        sampling_params = SamplingParams(
            temperature=1.0,
            nsigma=nsigma,
            seed=random.randint(0, 10000),
        )

        sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
                                    sampling_params, device)

        # Verify that sampling only selects from high logits
        for batch_idx, sequence_output in enumerate(sampler_output):
            allowed = high_logit_indices[batch_idx]
            for nth_output in sequence_output.samples:
                token_id = nth_output.output_token
                # The token should come from the high logits region
                assert token_id in allowed, \
                    f"Sampled token {token_id} for batch {batch_idx} was not in the high logit set"  # noqa

    # Test determinism: repeat the last configuration and compare with the
    # output of the final loop iteration.
    second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
                               sampling_params, device)
    assert sampler_output == second_output, \
        "Top-nsigma sampling is not deterministic with same seed"
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """Check the extra tensors on the output when probs are requested.

    With ``include_gpu_probs_tensor`` enabled and in-place greedy
    modification disabled, ``_modify_greedy_probs_inplace`` must not be
    called, and the output must carry ``sampled_token_probs``,
    ``logprobs`` and ``sampled_token_ids``.
    """
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    greedy_params = SamplingParams(temperature=0)

    inplace_spy = Mock()
    target = "aphrodite.modeling.layers.sampler._modify_greedy_probs_inplace"
    with patch(target, inplace_spy):
        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    greedy_params, device)
        inplace_spy.assert_not_called()

    assert sampler_output.sampled_token_probs is not None
    assert sampler_output.logprobs is not None
    assert sampler_output.sampled_token_ids is not None
|