test_seed.py

import pytest

from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # speculative model
        "speculative_model": "JackFram/llama-160m",

        # num speculative tokens
        "num_speculative_tokens": 3,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        20,
    ])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
                            batch_size: int, temperature: float,
                            output_len: int):
    """Verify outputs are consistent across multiple runs with the same seed.
    """
    run_equality_correctness_test(baseline_llm_generator,
                                  test_llm_generator,
                                  batch_size,
                                  max_output_len=output_len,
                                  temperature=temperature,
                                  seeded=True,
                                  force_output_len=True)

    # Ensure this same test does fail if we _don't_ include per-request seeds.
    with pytest.raises(AssertionError):
        run_equality_correctness_test(baseline_llm_generator,
                                      test_llm_generator,
                                      batch_size,
                                      max_output_len=output_len,
                                      temperature=temperature,
                                      seeded=False,
                                      force_output_len=True)
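

# The sketch below (not part of the original file) illustrates what
# "per-request seeding" is expected to provide, assuming vLLM's public
# LLM / SamplingParams API in which SamplingParams accepts a `seed`
# argument: with the same seed, repeated sampling at temperature > 0
# should yield identical text, which is the kind of reproducibility the
# test above exercises through run_equality_correctness_test.
if __name__ == "__main__":
    from vllm import LLM, SamplingParams

    llm = LLM(model="JackFram/llama-68m", enforce_eager=True)
    params = SamplingParams(temperature=1.0, max_tokens=20, seed=1234)

    # Generate twice with the same per-request seed and compare the text.
    first = llm.generate(["The quick brown fox"], params)[0].outputs[0].text
    second = llm.generate(["The quick brown fox"], params)[0].outputs[0].text
    assert first == second, "seeded sampling should be reproducible"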