test_integration.py

  1. """Tests which cover integration of the speculative decoding framework with
  2. other features, e.g. cuda graphs.
  3. """
  4. import pytest
  5. from .conftest import run_greedy_equality_correctness_test
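
# A hedged sketch of how the generator fixtures below are assembled (the real
# fixtures live in .conftest and may differ): each test LLM is built by
# merging the parametrized kwarg dicts, roughly
#
#   llm_kwargs = {**common_llm_kwargs, **per_test_common_llm_kwargs,
#                 **test_llm_kwargs}  # baseline merges baseline_llm_kwargs
#   llm = LLM(**llm_kwargs)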


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Required for spec decode.
        "use_v2_block_manager": True,

        # Verify equality when cuda graphs allowed.
        "enforce_eager": False,
        "model": "JackFram/llama-68m",
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Identical models.
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
                                batch_size, output_len):
    """Verify spec decode equality when cuda graphs are enabled.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
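

# Note (a reading of the helper's name and call sites, not a guaranteed
# contract): run_greedy_equality_correctness_test is expected to sample with
# temperature=0 from both the baseline and the spec-decode LLM and assert
# token-for-token equality; with force_output_len=True, each sequence is
# generated out to max_output_len.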


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
        "num_speculative_tokens": 5,
    },
])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Explicitly specify draft model quantization.
        {
            "speculative_model_quantization": "gptq",
        },
        # Explicitly specify that the GPTQ-based draft model should use
        # marlin quantization.
        {
            "speculative_model_quantization": "marlin",
        },
        # Do not explicitly specify draft model quantization.
        {
            "speculative_model_quantization": None,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(baseline_llm_generator,
                                               test_llm_generator,
                                               batch_size: int):
    """Verify spec decode works well with draft model quantization configs.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=32,
                                         force_output_len=True)
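

# Hedged usage sketch (not part of the test suite): the parametrized kwargs
# above correspond to vLLM engine arguments, so an equivalent standalone
# configuration would look roughly like the following. Argument names are
# taken from the dicts above; everything else is an assumption.
#
#   from vllm import LLM
#
#   llm = LLM(
#       model="JackFram/llama-160m",
#       speculative_model="LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
#       speculative_model_quantization="gptq",  # or "marlin", or None
#       num_speculative_tokens=5,
#       use_v2_block_manager=True,
#       enforce_eager=True,
#   )
#   outputs = llm.generate(["Hello, my name is"])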