- """This docstring details important information on the testing methodology.
- Most of the tests rely on "greedy equality", where we expect the output of
- speculative decoding on a sequence to exactly match the output of normal non-
- speculative decoding.
- Since speculative decoding with rejection sampling guarantees that the output
- distribution matches the target model's output distribution (up to hardware
- numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
- equality.
- However, we still need to verify below scenario could be passed:
- * Batch size 1 greedy equality
- * Batch size >1 greedy equality
- * Test greedy equality under preemption
- * Test greedy equality under various number of speculative tokens.
- With those tests, we can say at least, Medusa would not break the
- correctess for the target model outputs.
- """
import pytest

from .conftest import run_equality_correctness_test
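
# Illustrative sketch only: the real comparison is performed inside
# run_equality_correctness_test (see conftest). The helper below is not used
# by the tests; it merely documents the "greedy equality" invariant described
# in the module docstring, assuming outputs are represented as lists of
# per-sequence token-id lists.
def _greedy_outputs_match(baseline_token_ids, spec_token_ids):
    """Return True iff speculative decoding reproduced the baseline exactly."""
    return (len(baseline_token_ids) == len(spec_token_ids)
            and all(b == s
                    for b, s in zip(baseline_token_ids, spec_token_ids)))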

# main model
# lmsys/vicuna-7b-v1.3 was originally intended here, but it caused
# OOM in the CI pipeline, so a smaller model is used instead.
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
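
# Hedged sketch only: MAX_SPEC_TOKENS mirrors `num_heads` in the speculator's
# config.json, so in principle it could be read from the checkpoint instead of
# hard-coded. Not called here, to avoid a network fetch at import time; the
# custom Medusa config may additionally require trust_remote_code=True.
def _num_medusa_heads(model_name=SPEC_MODEL):
    from transformers import AutoConfig
    return AutoConfig.from_pretrained(model_name,
                                      trust_remote_code=True).num_heads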

# precision
PRECISION = "float32"

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness(aphrodite_runner, common_llm_kwargs,
                                       per_test_common_llm_kwargs,
                                       baseline_llm_kwargs, test_llm_kwargs,
                                       batch_size: int, output_len: int,
                                       seed: int):
    """Verify greedy equality with different batch sizes."""
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_cuda_graph(
        aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        output_len: int, seed: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 blocks for the small prompt, 256 // 8 = 32 blocks for generated
        # tokens.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
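# How preemption is forced: 2 + 256 // 8 = 34 KV-cache blocks of 8 tokens
# each give 272 token slots, shared by a batch of 4 sequences that each
# generate 128 tokens, so the scheduler must preempt sequences
# mid-generation.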
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_with_preemption(
        aphrodite_runner, common_llm_kwargs, per_test_common_llm_kwargs,
        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
        output_len: int, seed: int):
    """Verify greedy equality, even when some sequences are preempted
    mid-generation.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens.
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_medusa_different_k(aphrodite_runner, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
                            test_llm_kwargs, batch_size: int, output_len: int,
                            seed: int):
    """Verify that Medusa speculative decoding produces exact equality
    with non-speculative decoding for different values of
    num_speculative_tokens.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": SPEC_MODEL,
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
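# speculative_disable_by_batch_size=4 turns speculation off for large
# batches: batch_size=1 keeps the speculative path active, while
# batch_size=5 falls back to normal decoding. Both cases must still match
# the non-speculative baseline exactly.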
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_medusa_disable_queue(aphrodite_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int):
    """Verify that Medusa speculative decoding produces exact equality
    with non-speculative decoding when speculation is disabled for large
    batch sizes.
    """
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)

if __name__ == "__main__":
    pytest.main([__file__])