test_integration.py
  1. """Tests which cover integration of the speculative decoding framework with
  2. other features, e.g. cuda graphs.
  3. """
  4. import pytest
  5. from .conftest import run_equality_correctness_test
  6. MAIN_MODEL = "JackFram/llama-68m"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Required for spec decode.
        "use_v2_block_manager": True,

        # Verify equality when cuda graphs allowed.
        "enforce_eager": False,
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Identical draft and target models.
            "speculative_model": MAIN_MODEL,
            "num_speculative_tokens": 5,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(aphrodite_runner, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, test_llm_kwargs,
                                batch_size: int, output_len: int, seed: int):
    """Verify spec decode equality when cuda graphs are enabled."""
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
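

# The parametrized kwargs layers above are combined by the conftest harness
# before each engine is constructed. A minimal sketch of the assumed merge
# semantics (plain dict updates, with later layers winning; the real harness
# in .conftest may differ):
def _merged_llm_kwargs_sketch(common: dict, per_test_common: dict,
                              variant: dict) -> dict:
    """Hypothetical helper illustrating how the baseline/test engine kwargs
    are layered from the parametrized fixtures."""
    return {**common, **per_test_common, **variant}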


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Explicitly specify the draft model quantization.
        {
            "speculative_model_quantization": "gptq",
        },
        # Explicitly ask the GPTQ-based draft model to use marlin quantization.
        {
            "speculative_model_quantization": "marlin",
        },
        # Do not explicitly specify the draft model quantization.
        {
            "speculative_model_quantization": None,
        },
    ])
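# In the None case, the draft model's quantization is presumably detected from
# the GPTQ checkpoint's own config, so all three variants should match the
# baseline output.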
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(aphrodite_runner,
                                               common_llm_kwargs,
                                               per_test_common_llm_kwargs,
                                               baseline_llm_kwargs,
                                               test_llm_kwargs,
                                               batch_size: int,
                                               seed: int):
    """Verify spec decode works well with draft model quantization configs."""
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=32,
                                  seed=seed,
                                  temperature=0.0)
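

# Example invocation (the path is illustrative; adjust it to wherever this
# module lives in the repository):
#
#   pytest tests/spec_decode/e2e/test_integration.py -v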