test_seeded_generate.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """Verify that seeded random sampling is deterministic.
  2. Run `pytest tests/samplers/test_seeded_generate.py --forked`.
  3. """
  4. import copy
  5. import random
  6. from itertools import combinations
  7. import pytest
  8. from aphrodite.modeling.utils import set_random_seed
  9. from aphrodite import SamplingParams
  10. MODEL = "EleutherAI/pythia-70m-deduped"
  11. RANDOM_SEEDS = list(range(5))
  12. @pytest.fixture
  13. def aphrodite_model(aphrodite_runner):
  14. aphrodite_model = aphrodite_runner(MODEL, dtype="half")
  15. yield aphrodite_model
  16. del aphrodite_model
  17. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  18. def test_random_sample_with_seed(
  19. aphrodite_model,
  20. example_prompts,
  21. seed: int,
  22. ) -> None:
  23. set_random_seed(seed)
  24. sampling_params = SamplingParams(
  25. # Parameters to ensure sufficient randomness
  26. temperature=2.0,
  27. top_p=min(random.random() + 0.3, 1),
  28. top_k=random.randint(5, 20),
  29. n=random.randint(1, 10),
  30. presence_penalty=random.randint(0, 1),
  31. max_tokens=8,
  32. ignore_eos=True,
  33. )
  34. sampling_params_seed_1 = copy.deepcopy(sampling_params)
  35. sampling_params_seed_1.seed = 100
  36. sampling_params_seed_2 = copy.deepcopy(sampling_params)
  37. sampling_params_seed_2.seed = 200
  38. llm = aphrodite_model.model
  39. for prompt in example_prompts:
  40. for params in (
  41. sampling_params,
  42. sampling_params_seed_1,
  43. sampling_params_seed_2,
  44. sampling_params,
  45. sampling_params_seed_1,
  46. sampling_params_seed_2,
  47. ):
  48. llm._add_request(
  49. prompt=prompt,
  50. prompt_token_ids=None,
  51. sampling_params=params,
  52. )
  53. results = llm._run_engine(use_tqdm=False)
  54. all_outputs = [[out.token_ids for out in output.outputs]
  55. for output in results]
  56. for i in range(0, len(example_prompts), 6):
  57. outputs = all_outputs[i:i + 6]
  58. # verify all non-seeded requests differ
  59. for output_a, output_b in combinations(
  60. (outputs[0], outputs[1], outputs[2], outputs[3]),
  61. 2,
  62. ):
  63. assert output_a != output_b
  64. # verify requests with the same seed match
  65. assert outputs[1] == outputs[4]
  66. assert outputs[2] == outputs[5]