# test_integration_dist_tp2.py

  1. """Tests which cover integration of the speculative decoding framework with
  2. tensor parallelism.
  3. """
  4. import pytest
  5. import torch
  6. from aphrodite.common.utils import is_hip
  7. from .conftest import run_equality_correctness_test_tp
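

# `run_equality_correctness_test_tp` (see conftest) drives the same prompts
# through a baseline engine and a speculative-decoding engine built from the
# CLI-style flags below, and asserts that the greedy (temperature=0.0)
# outputs match exactly.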
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",

        # Required for spec decode.
        "--use-v2-block-manager",
        "--tensor-parallel-size",
        "2"
    ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
  22. @pytest.mark.parametrize("test_llm_kwargs", [
  23. [
  24. "--speculative-model",
  25. "JackFram/llama-68m",
  26. "--num-speculative-tokens",
  27. "3",
  28. ],
  29. [
  30. "--speculative-model",
  31. "[ngram]",
  32. "--num-speculative-tokens",
  33. "5",
  34. "--ngram-prompt-lookup-max",
  35. "3",
  36. ],
  37. ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
                              baseline_llm_kwargs, test_llm_kwargs,
                              batch_size: int, output_len: int, seed: int):
    """Verify greedy equality when tensor parallelism is used.
    """
    if is_hip():
        pytest.skip("hip is not well-supported yet")
    run_equality_correctness_test_tp("JackFram/llama-68m",
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     output_len,
                                     seed,
                                     temperature=0.0)
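

# For manual experimentation outside pytest, a rough offline-API equivalent
# of the first configuration above (a sketch only: it assumes Aphrodite
# exposes a vLLM-style `LLM` class accepting these keyword arguments):
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="JackFram/llama-68m",
#               tensor_parallel_size=2,
#               enforce_eager=True,
#               use_v2_block_manager=True,
#               speculative_model="JackFram/llama-68m",
#               num_speculative_tokens=3)
#     out = llm.generate(["San Francisco is"],
#                        SamplingParams(temperature=0.0, max_tokens=32))
#     print(out[0].outputs[0].text)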


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",

        # Required for spec decode.
        "--use-v2-block-manager",
        "--tensor-parallel-size",
        "2",

        # precision
        "--dtype",
        "bfloat16",
    ]])
  77. @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
  78. @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
  79. @pytest.mark.parametrize("model, test_llm_kwargs",
  80. [("JackFram/llama-68m", [
  81. "--speculative-model",
  82. "JackFram/llama-68m",
  83. "--num_speculative-tokens",
  84. "5",
  85. "--speculative-draft-tensor-parallel-size",
  86. "1",
  87. ]),
  88. ("ibm-granite/granite-3b-code-instruct", [
  89. "--speculative-model",
  90. "ibm-granite/granite-3b-code-instruct",
  91. "--num_speculative-tokens",
  92. "5",
  93. "--speculative-draft-tensor-parallel-size",
  94. "1",
  95. ])])
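# With --speculative-draft-tensor-parallel-size 1, the draft model runs on a
# single GPU while the target model stays sharded across both, avoiding
# tensor-parallel communication overhead for the (small) draft forward passes.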
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size: int,
                                            seed: int):
    """Verify spec decode works well when the draft model uses a smaller
    tensor parallel size than the target model.
    """
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0)
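
# A minimal sketch of the equality check these tests rely on (the real helper
# lives in conftest; the names here are illustrative, not its actual API):
#
#     def assert_greedy_equal(baseline_outputs, spec_outputs):
#         for base, spec in zip(baseline_outputs, spec_outputs):
#             # Speculative decoding must be lossless under greedy sampling.
#             assert base.outputs[0].token_ids == spec.outputs[0].token_ids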