@@ -344,3 +344,69 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                 b=baseline_rank_to_logprob[rank],
                 abs_tol=1e-1,
             )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "max_logprobs": 6,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
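+                             # Skip logprob computation during spec decode;
+                             # placeholder values are returned instead.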
+                             "disable_logprobs_during_spec_decoding": True,
+                         }])
+@pytest.mark.parametrize("seed", [1])
+def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
+    """Check the behavior when logprobs are disabled.
+    Token choices should match those of the base model.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is known for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
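+    # Trim to a batch of 4 prompts to keep the test fast.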
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
+    sampling_params = SamplingParams(
+        # Use smaller output len for fast test.
+        max_tokens=7,
+        ignore_eos=True,
+        temperature=0.0,
+        logprobs=2,
+    )
+    spec_batch_logprobs = get_logprobs_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+    baseline_batch_logprobs = get_logprobs_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+    assert len(baseline_batch_logprobs) == len(prompts)
+    assert len(spec_batch_logprobs) == len(prompts)
+    # For each sequence in the batch.
+    for baseline_logprobs, spec_logprobs in zip(
+            baseline_batch_logprobs, spec_batch_logprobs):
+        assert len(spec_logprobs) == len(baseline_logprobs)
+        # For each generated position of the sequence.
+        for spec_pos_logprobs, baseline_pos_logprobs in zip(
+                spec_logprobs, baseline_logprobs):
+            assert len(spec_pos_logprobs) == 1
+            spec_top_token_id = list(spec_pos_logprobs)[0]
+            spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
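+            # With logprobs disabled, spec decode returns a dummy entry
+            # for the sampled token: logprob 0.0 and rank -1.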
+            assert spec_top_logprob.logprob == 0.0
+            assert spec_top_logprob.rank == -1
+            # Check that the chosen token matches the base model.
+            baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
+            assert baseline_logprob.rank == 1
+            assert (spec_top_logprob.decoded_token ==
+                    baseline_logprob.decoded_token)