test_seed.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import pytest
  2. from .conftest import run_equality_correctness_test
  3. # main model
  4. MAIN_MODEL = "JackFram/llama-68m"
  5. # speculative model
  6. SPEC_MODEL = "JackFram/llama-160m"
  7. @pytest.mark.parametrize(
  8. "common_llm_kwargs",
  9. [{
  10. "model_name": "JackFram/llama-68m",
  11. # Skip cuda graph recording for fast test.
  12. "enforce_eager": True,
  13. # Required for spec decode.
  14. "use_v2_block_manager": True,
  15. # speculative model
  16. "speculative_model": "JackFram/llama-160m",
  17. # num speculative tokens
  18. "num_speculative_tokens": 3,
  19. }])
  20. @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
  21. @pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
  22. @pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
  23. @pytest.mark.parametrize("batch_size", [1, 8, 32])
  24. @pytest.mark.parametrize("temperature", [0.1, 1.0])
  25. @pytest.mark.parametrize(
  26. "output_len",
  27. [
  28. # Use smaller output len for fast test.
  29. 20,
  30. ])
  31. def test_seeded_consistency(aphrodite_runner, common_llm_kwargs,
  32. per_test_common_llm_kwargs, baseline_llm_kwargs,
  33. test_llm_kwargs, batch_size: int,
  34. temperature: float, output_len: int):
  35. """Verify outputs are consistent across multiple runs with same seed
  36. """
  37. run_equality_correctness_test(
  38. aphrodite_runner,
  39. common_llm_kwargs,
  40. per_test_common_llm_kwargs,
  41. baseline_llm_kwargs,
  42. test_llm_kwargs,
  43. batch_size,
  44. max_output_len=output_len,
  45. temperature=temperature,
  46. disable_seed=False,
  47. )
  48. # Ensure this same test does fail if we _don't_ include per-request seeds
  49. with pytest.raises(AssertionError):
  50. run_equality_correctness_test(
  51. aphrodite_runner,
  52. common_llm_kwargs,
  53. per_test_common_llm_kwargs,
  54. baseline_llm_kwargs,
  55. test_llm_kwargs,
  56. batch_size,
  57. max_output_len=output_len,
  58. temperature=temperature,
  59. disable_seed=True,
  60. )