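"""End-to-end tests verifying that unsupported speculative decoding
configurations are rejected with a ValueError."""
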
import pytest

from aphrodite import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "enable_chunked_prefill": True,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
    """Verify that speculative decoding combined with chunked prefill is
    rejected with a ValueError.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding and chunked prefill"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Speculative max model len > overridden max model len should raise.
            "max_model_len": 128,
            "speculative_max_model_len": 129,
        },
        {
            # Speculative max model len > draft max model len should raise.
            # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
            "speculative_max_model_len": 2048 + 1,
        },
        {
            # Speculative max model len > target max model len should raise.
            # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
            "speculative_max_model_len": 4096 + 1,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
    """Verify that speculative decoding validates speculative_max_model_len.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError, match="cannot be larger than"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
    """Verify that speculative decoding with block manager v1 is rejected
    with a ValueError.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding requires usage of the V2"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
|