|
- """This docstring details important information on the testing methodology.
- Most of the tests rely on "greedy equality", where we expect the output of
- speculative decoding on a sequence to exactly match the output of normal non-
- speculative decoding.
- Since speculative decoding with rejection sampling guarantees that the output
- distribution matches the target model's output distribution (up to hardware
- numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
- equality.
- The idea of ngram lookup comes from https://github.com/apoorvumang/prompt-lookup-decoding,
- and has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
- Since no model is needed to generate the proposals, we can make
- the test cases much simpler than the multi-step drafter ones.
- However, we still need to verify that the scenarios below pass:
- * Batch size 1 greedy equality
- * Batch size >1 greedy equality
- * Test greedy equality under preemption
- * Test greedy equality under various ngram sizes / speculative sizes
- With those tests, we can at least say that ngram speculation does not break
- the correctness of the target model outputs.
- """
- import pytest
- from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Eager mode avoids CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding requires the v2 block manager.
            "use_v2_block_manager": True,
            # Leave stats logging enabled so spec decode metrics are printed.
            "disable_log_stats": False,
        }
    ],
)
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [{"model": "JackFram/llama-68m"}],
)
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": 5,
            "ngram_prompt_lookup_max": 3,
        }
    ],
)
@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(
    baseline_llm_generator,
    test_llm_generator,
    batch_size: int,
    output_len: int,
):
    """Check greedy equality between ngram spec decode and the baseline.

    Runs a tiny model with batch sizes 1 and 32 and forces every sequence
    to generate exactly ``output_len`` tokens.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "block_size": 8,
            # Block budget: 2 blocks for the small prompt plus 256 // 8
            # blocks for the generated tokens.
            "num_gpu_blocks_override": 2 + 256 // 8,
            "max_model_len": (2 + 256 // 8) * 8,
            # Eager mode avoids CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding requires the v2 block manager.
            "use_v2_block_manager": True,
        }
    ],
)
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [{"model": "JackFram/llama-160m"}],
)
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": 5,
            "ngram_prompt_lookup_max": 3,
        }
    ],
)
# A short output length keeps the test fast.
@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(
    baseline_llm_generator,
    test_llm_generator,
    batch_size: int,
    output_len: int,
):
    """Check greedy equality when sequences get preempted mid-generation.

    The GPU block budget is deliberately small relative to the batch, which
    is intended to force preemption while the sequences are being generated.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",
            # Eager mode avoids CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,
            # Speculative decoding requires the v2 block manager.
            "use_v2_block_manager": True,
        }
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": lookup_max,
        }
        # Sweep k over 1, 3 and 5 for each lookup window: first with
        # ngram_prompt_lookup_max=3, then with ngram_prompt_lookup_max=1.
        for lookup_max in (3, 1)
        for k in (1, 3, 5)
    ],
)
@pytest.mark.parametrize("batch_size", [2])
# A smaller output length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(
    baseline_llm_generator,
    test_llm_generator,
    batch_size: int,
    output_len: int,
):
    """Check greedy equality across several k and ngram lookup sizes.

    Each combination of ``num_speculative_tokens`` and
    ``ngram_prompt_lookup_max`` must produce exactly the same output as
    decoding without speculation.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",
            # Skip cuda graph recording for fast test.
            "enforce_eager": True,
            # Required for spec decode.
            "use_v2_block_manager": True,
        }
    ],
)
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": 5,
            "ngram_prompt_lookup_max": 3,
            # Threshold under test; the batch sizes below sit on either
            # side of it (1 < 4 < 5).
            "speculative_disable_by_batch_size": 4,
        }
    ],
)
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ],
)
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(
    baseline_llm_generator,
    test_llm_generator,
    batch_size: int,
    output_len: int,
):
    """Verify greedy equality when ``speculative_disable_by_batch_size`` is set.

    The test engine is configured with ``speculative_disable_by_batch_size=4``
    and is run with batch sizes 1 and 5, i.e. one below and one above that
    threshold. In both cases the output must exactly match decoding without
    speculation.

    NOTE(review): presumably speculation is switched off once the batch
    grows past the threshold; confirm the exact trigger against the engine
    documentation.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
|