test_seeded_generate.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """Verify that seeded random sampling is deterministic.
  2. Run `pytest tests/samplers/test_seeded_generate.py`.
  3. """
  4. import copy
  5. import random
  6. from itertools import combinations
  7. import pytest
  8. from aphrodite import SamplingParams
  9. from aphrodite.modeling.utils import set_random_seed
  10. MODEL = "facebook/opt-125m"
  11. RANDOM_SEEDS = list(range(5))
  12. @pytest.fixture
  13. def aphrodite_model(aphrodite_runner):
  14. with aphrodite_runner(MODEL, dtype="half") as aphrodite_model:
  15. yield aphrodite_model
  16. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  17. def test_random_sample_with_seed(
  18. aphrodite_model,
  19. example_prompts,
  20. seed: int,
  21. ) -> None:
  22. set_random_seed(seed)
  23. sampling_params = SamplingParams(
  24. # Parameters to ensure sufficient randomness
  25. temperature=2.0,
  26. top_p=min(random.random() + 0.3, 1),
  27. top_k=random.randint(5, 20),
  28. n=random.randint(1, 10),
  29. presence_penalty=random.randint(0, 1),
  30. max_tokens=8,
  31. ignore_eos=True,
  32. )
  33. sampling_params_seed_1 = copy.deepcopy(sampling_params)
  34. sampling_params_seed_1.seed = 100
  35. sampling_params_seed_2 = copy.deepcopy(sampling_params)
  36. sampling_params_seed_2.seed = 200
  37. llm = aphrodite_model.model
  38. for prompt in example_prompts:
  39. for params in (
  40. sampling_params,
  41. sampling_params_seed_1,
  42. sampling_params_seed_2,
  43. sampling_params,
  44. sampling_params_seed_1,
  45. sampling_params_seed_2,
  46. ):
  47. llm._add_request(prompt, params=params)
  48. results = llm._run_engine(use_tqdm=False)
  49. all_outputs = [[out.token_ids for out in output.outputs]
  50. for output in results]
  51. for i in range(0, len(example_prompts), 6):
  52. outputs = all_outputs[i:i + 6]
  53. # verify all non-seeded requests differ
  54. for output_a, output_b in combinations(
  55. (outputs[0], outputs[1], outputs[2], outputs[3]),
  56. 2,
  57. ):
  58. assert output_a != output_b
  59. # verify requests with the same seed match
  60. assert outputs[1] == outputs[4]
  61. assert outputs[2] == outputs[5]