1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- """Verify that seeded random sampling is deterministic.
- Run `pytest tests/samplers/test_seeded_generate.py`.
- """
- import copy
- import random
- from itertools import combinations
- import pytest
- from aphrodite import SamplingParams
- from aphrodite.modeling.utils import set_random_seed
# Model identifier handed to the runner by the fixture in this module.
MODEL = "facebook/opt-125m"
# Seeds the test is parametrized over: 0, 1, 2, 3, 4.
RANDOM_SEEDS = [*range(5)]
@pytest.fixture
def aphrodite_model(aphrodite_runner):
    """Yield a runner for ``MODEL`` created with ``dtype="half"``.

    The runner is used as a context manager, so it is exited once the
    requesting test has finished with the yielded object.
    """
    runner_ctx = aphrodite_runner(MODEL, dtype="half")
    with runner_ctx as runner:
        yield runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    aphrodite_model,
    example_prompts,
    seed: int,
) -> None:
    """Check that a per-request seed makes random sampling reproducible.

    For every prompt six requests are queued, in this order: unseeded,
    seed 100, seed 200, and then the same three again. After the engine has
    run, each group of six results must satisfy:

    * the first four results (unseeded, seed 100, seed 200, unseeded) are
      pairwise different, and
    * each seeded request reproduces the output of its twin with the same
      seed.

    Args:
        aphrodite_model: Runner fixture; its ``model`` attribute is the
            object used to queue requests and run the engine.
        example_prompts: Sized iterable of prompts, supplied by a fixture
            defined outside this file.
        seed: Value passed to ``set_random_seed`` before the sampling
            parameters are drawn.
    """
    set_random_seed(seed)

    # NOTE(review): the random.* draws below are presumably made
    # reproducible by set_random_seed(seed) — confirm it seeds `random`.
    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=2.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    # Every prompt is submitted with this sequence of parameter sets; the
    # second half repeats the first so seeded outputs can be compared.
    request_variants = (
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
    ) * 2
    requests_per_prompt = len(request_variants)

    llm = aphrodite_model.model

    for prompt in example_prompts:
        for params in request_variants:
            llm._add_request(prompt, params=params)

    # NOTE(review): the grouping below assumes _run_engine returns one
    # result per request, in the order the requests were added — confirm.
    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    expected_total = requests_per_prompt * len(example_prompts)
    assert len(all_outputs) == expected_total, (
        f"expected {expected_total} results, got {len(all_outputs)}")

    # Step over the results, not the prompts: there are
    # `requests_per_prompt` results per prompt, so bounding the range by
    # len(example_prompts) would leave most groups unchecked.
    for start in range(0, len(all_outputs), requests_per_prompt):
        outputs = all_outputs[start:start + requests_per_prompt]

        # verify requests with different seeds, or with no seed, all differ
        for output_a, output_b in combinations(outputs[:4], 2):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]
|