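"""Tests for logprob correctness with and without speculative decoding.

These end-to-end tests exercise greedy logprob equality across different
speculation lengths, skipped speculation, sampling at temperature 1.0, and
the disabled-logprobs fast path.
"""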
from itertools import cycle

import pytest

from aphrodite import SamplingParams

from .conftest import run_logprob_correctness_test
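
# Most tests below delegate to `run_logprob_correctness_test` from this
# package's conftest; it is assumed to run the same prompts through both a
# baseline engine and a spec-decode engine and compare the sampled tokens
# and their logprobs. See conftest.py for the exact contract.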


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_equality(aphrodite_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Verify output logprobs are equal with and without speculative decoding.
    """
    run_logprob_correctness_test(aphrodite_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size,
                                 output_len,
                                 seed,
                                 temperature=0.0,
                                 logprobs=logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }, {
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 6,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(aphrodite_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lengths.
    """
    run_logprob_correctness_test(aphrodite_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size,
                                 output_len,
                                 seed,
                                 temperature=0.0,
                                 logprobs=logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,
        # Artificially limit the draft model max model len; this forces
        # Aphrodite to skip speculation once the sequences grow beyond 32-k
        # tokens.
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(aphrodite_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprobs greedy equality when some sequences skip speculation.
    """
    run_logprob_correctness_test(aphrodite_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size,
                                 output_len,
                                 seed,
                                 temperature=0.0,
                                 logprobs=logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-160m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": False,
                         }])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(aphrodite_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify that at least one logprob result contains num_logprobs+1
    entries, which exercises the case where the sampled token is not in the
    top-k logprobs.

    Ideally this test would also validate logprob equality with non-spec
    decoding; that is left as a future improvement.
    """
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is known for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }
    with aphrodite_runner(**sd_args) as aphrodite_model:
        sd_outputs = aphrodite_model.generate_w_logprobs(prompts,
                                                         sampling_params)
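    # The exact return shape of `generate_w_logprobs` is defined by the test
    # runner; the last element is assumed here to hold the per-position
    # logprob dicts, whose lengths the assertion below inspects.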
    num_returned_logprobs = [
        len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-160m",
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": "JackFram/llama-68m",
                             "num_speculative_tokens": 3,
                             "disable_logprobs_during_spec_decoding": True,
                         }])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(aphrodite_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check behavior when logprobs are disabled: token choices should still
    match the baseline model.
    """
    run_logprob_correctness_test(aphrodite_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size,
                                 output_len,
                                 seed,
                                 temperature=0.0,
                                 logprobs=logprobs)