- """Tests which cover integration of the speculative decoding framework with
- tensor parallelism.
- """
- import pytest
- import torch
- from aphrodite.common.utils import is_hip
- from .conftest import run_equality_correctness_test_tp
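
# Each test below runs a baseline engine and a spec-decode engine with the
# same server-style argument list and checks that greedy outputs match
# token-for-token (see run_equality_correctness_test_tp in .conftest).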


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",

        # Required for spec decode.
        "--use-v2-block-manager",
        "--tensor-parallel-size",
        "2"
    ]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
- @pytest.mark.parametrize("test_llm_kwargs", [
- [
- "--speculative-model",
- "JackFram/llama-68m",
- "--num-speculative-tokens",
- "3",
- ],
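    # "[ngram]" is a sentinel model name that enables prompt-lookup (n-gram)
    # speculation: draft tokens are proposed by matching against the prompt
    # rather than by running a separate draft model.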
    [
        "--speculative-model",
        "[ngram]",
        "--num-speculative-tokens",
        "5",
        "--ngram-prompt-lookup-max",
        "3",
    ],
])
- @pytest.mark.parametrize("batch_size", [2])
- @pytest.mark.parametrize(
- "output_len",
- [
- # Use smaller output len for fast test.
- 32,
- ])
- @pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
                              baseline_llm_kwargs, test_llm_kwargs,
                              batch_size: int, output_len: int, seed: int):
    """Verify greedy equality when the target model runs with tensor
    parallelism.
    """
    if is_hip():
        pytest.skip("hip is not well-supported yet")
    run_equality_correctness_test_tp("JackFram/llama-68m",
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     output_len,
                                     seed,
                                     temperature=0.0)
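
# For reference, a minimal offline sketch of the first configuration above.
# This is a sketch, not part of the test suite, and it assumes the Aphrodite
# `LLM` entrypoint accepts keyword forms of the CLI flags used in these tests:
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="JackFram/llama-68m",
#               tensor_parallel_size=2,
#               speculative_model="JackFram/llama-68m",
#               num_speculative_tokens=3,
#               use_v2_block_manager=True,
#               enforce_eager=True)
#     outputs = llm.generate(["Hello, my name is"],
#                            SamplingParams(temperature=0.0, max_tokens=32))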


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [[
        # Skip cuda graph recording for fast test.
        "--enforce-eager",

        # Required for spec decode.
        "--use-v2-block-manager",
        "--tensor-parallel-size",
        "2",

        # precision
        "--dtype",
        "bfloat16",
    ]])
- @pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
- @pytest.mark.parametrize("baseline_llm_kwargs", [[]])
- @pytest.mark.parametrize("model, test_llm_kwargs",
- [("JackFram/llama-68m", [
- "--speculative-model",
- "JackFram/llama-68m",
- "--num_speculative-tokens",
- "5",
- "--speculative-draft-tensor-parallel-size",
- "1",
- ]),
- ("ibm-granite/granite-3b-code-instruct", [
- "--speculative-model",
- "ibm-granite/granite-3b-code-instruct",
- "--num_speculative-tokens",
- "5",
- "--speculative-draft-tensor-parallel-size",
- "1",
- ])])
- @pytest.mark.parametrize("batch_size", [2])
- @pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
                                            per_test_common_llm_kwargs,
                                            baseline_llm_kwargs,
                                            test_llm_kwargs, batch_size: int,
                                            seed: int):
    """Verify spec decode works correctly when the draft model runs with a
    smaller tensor-parallel degree than the target model.
    """
    run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size,
                                     max_output_len=32,
                                     seed=seed,
                                     temperature=0.0)
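
# Note: running the draft model with tensor-parallel size 1 while the target
# uses size 2 avoids tensor-parallel communication overhead for the much
# smaller draft model; correctness should be unaffected, which is what the
# test above asserts.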