- """This docstring details important information on the testing methodology.
- Most of the tests rely on "greedy equality", where we expect the output of
- speculative decoding on a sequence to exactly match the output of normal non-
- speculative decoding.
- Since speculative decoding with rejection sampling guarantees that the output
- distribution matches the target model's output distribution (up to hardware
- numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
- equality.
- However, we still need to verify below scenario could be passed:
- * Batch size 1 greedy equality
- * Batch size >1 greedy equality
- * Test greedy equality under preemption
- * Test greedy equality under various number of speculative tokens.
- With those tests, we can say at least, Medusa would not break the
- correctess for the target model outputs.
- """
import pytest

from .conftest import run_greedy_equality_correctness_test
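

# For orientation, the greedy-equality check used throughout this file boils
# down to the token-level comparison sketched below. This is an illustrative
# sketch only: the real logic lives in
# .conftest.run_greedy_equality_correctness_test, and the output shape it
# consumes here is an assumption. The helper is not used by the tests.
def _greedy_equality_sketch(baseline_token_ids: list,
                            spec_token_ids: list) -> None:
    # Under greedy sampling, speculative decoding with rejection sampling
    # must reproduce the baseline output token-for-token, so any mismatch
    # indicates a correctness bug in the speculative path.
    assert baseline_token_ids == spec_token_ids
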
# Main model: lmsys/vicuna-7b-v1.3 would be preferable, but it causes OOM in
# the CI pipeline, so a smaller model is used instead.
MAIN_MODEL = "JackFram/llama-68m"

# Speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

# Max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
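

# Hedged sanity-check sketch (not invoked by the tests): one way to confirm
# MAX_SPEC_TOKENS against the speculator's config.json. It assumes the Medusa
# config stores the head count under a "num_heads" key and that the file can
# be fetched with huggingface_hub; both are assumptions, not vLLM guarantees.
def _read_num_heads_from_config(model_name: str) -> int:
    import json

    from huggingface_hub import hf_hub_download

    config_path = hf_hub_download(model_name, "config.json")
    with open(config_path) as f:
        return json.load(f)["num_heads"]
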
# Precision
PRECISION = "float32"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
                                       test_llm_generator, batch_size: int,
                                       output_len: int):
    """Verify greedy equality with different batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # Artificially limit the available GPU blocks so that generating
        # 256 tokens forces preemption: 2 blocks cover the small prompt,
        # and 256 // 8 blocks cover the generated tokens.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
- @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
- @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
- @pytest.mark.parametrize("test_llm_kwargs", [
- {
- "speculative_model": SPEC_MODEL,
- "num_speculative_tokens": MAX_SPEC_TOKENS,
- },
- ])
- @pytest.mark.parametrize(
- "output_len",
- [
- # Use small output len for fast test.
- 128,
- ])
- @pytest.mark.parametrize("batch_size", [4])
- @pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                       test_llm_generator,
                                                       batch_size: int,
                                                       output_len: int):
    """Verify greedy equality, even when some sequences are preempted
    mid-generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens.
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int):
    """Verify that Medusa speculative decoding produces output exactly equal
    to non-speculative decoding for different values of
    num_speculative_tokens.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": SPEC_MODEL,
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4,
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify that Medusa speculative decoding produces output exactly equal
    to non-speculative decoding when speculation is disabled for large
    batch sizes.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


if __name__ == "__main__":
    pytest.main([__file__])