Procházet zdrojové kódy

spec decode: fix logprobs when using speculative decoding (#904)

AlpinDale před 2 měsíci
rodič
revize
ab533e0e60

+ 9 - 7
aphrodite/spec_decode/util.py

@@ -64,24 +64,26 @@ def create_sequence_group_output(
         token_id_logprob_rank (int): The logprob rank of the sampled token.
         token_id_logprob (float): The logprob value of the sampled token.
         seq_id (int): The sequence id.
-        topk_token_ids (List[int]): The list of top-k token ids.
-        topk_logprobs (List[float]): The list of top-k logprobs.
+        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
+        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
     """
     # Aphrodite logprobs always include the sampled token. In addition, the
     # user may request topk-logprobs (where top-k varies per user up to
     # max_logprobs).
-    logprobs: Dict[Optional[int], Logprob] = {
+    logprobs: Dict[int, Logprob] = {
         token_id: Logprob(
             logprob=token_id_logprob,
             rank=token_id_logprob_rank,
         ),
     }
     logprobs.update({
-        topk_token_ids[topk_logprob_index]: Logprob(
-            logprob=topk_logprobs[topk_logprob_index],
-            rank=topk_logprob_index + 1,
+        topk_token_id: Logprob(
+            logprob=topk_logprob if topk_logprob is not None else 0.0,
+            rank=topk_index + 1,
         )
-        for topk_logprob_index, _ in enumerate(topk_token_ids)
+        for topk_index, (topk_token_id, topk_logprob) \
+            in enumerate(zip(topk_token_ids, topk_logprobs)) \
+        if topk_token_id is not None
     })
 
     return CompletionSequenceGroupOutput(

+ 66 - 0
tests/spec_decode/e2e/test_logprobs.py

@@ -344,3 +344,69 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                     b=baseline_rank_to_logprob[rank],
                     abs_tol=1e-1,
                 )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "max_logprobs": 6,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": True,
+                         }])
+@pytest.mark.parametrize("seed", [1])
+def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
+    """Check the behavior when logprobs are disabled.
+    Token choices should match with the base model.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
+    sampling_params = SamplingParams(
+        # Use smaller output len for fast test
+        max_tokens=7,
+        ignore_eos=True,
+        temperature=0.0,
+        logprobs=2,
+    )
+    spec_batch_logprobs = get_logprobs_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+    baseline_batch_logprobs = get_logprobs_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+    assert len(baseline_batch_logprobs) == len(prompts)
+    assert len(spec_batch_logprobs) == len(prompts)
+    # For each sequence in the batch.
+    for _, (baseline_logprobs, spec_logprobs) in enumerate(
+            zip(baseline_batch_logprobs, spec_batch_logprobs)):
+        assert len(spec_logprobs) == len(baseline_logprobs)
+        # For each generated position of the sequence.
+        for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+                zip(spec_logprobs, baseline_logprobs)):
+            assert len(spec_pos_logprobs) == 1
+            spec_top_token_id = list(spec_pos_logprobs)[0]
+            spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
+            assert spec_top_logprob.logprob == 0.0
+            assert spec_top_logprob.rank == -1
+            # check that the chosen token matches the base model
+            baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
+            assert baseline_logprob.rank == 1
+            assert spec_top_logprob.decoded_token \
+                == baseline_logprob.decoded_token