
fix ngrams

AlpinDale committed 8 months ago
commit 438f5bdce9
2 changed files with 12 additions and 7 deletions
  1. aphrodite/executor/gpu_executor.py (+4, -0)
  2. aphrodite/spec_decode/spec_decode_worker.py (+8, -7)

aphrodite/executor/gpu_executor.py (+4, -0)

@@ -81,6 +81,10 @@ class GPUExecutor(ExecutorBase):
         draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
+            ngram_prompt_lookup_max=self.speculative_config.
+            ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min=self.speculative_config.
+            ngram_prompt_lookup_min,
             # TODO allow draft-model specific load config.
             #load_config=self.load_config,
         )
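
Read together with the worker change below, this looks like the actual fix: before this commit the executor never forwarded the two ngram bounds, so the worker's old `"ngram_prompt_lookup_max" in draft_worker_kwargs` guard always fell back to 0 and NGramWorker was never selected. A rough sketch of the hand-off after the change; `base_worker_kwargs`, `scorer_worker`, and the exact `create_worker` call shape are assumptions about the surrounding executor code, not part of the diff:

    # Executor side (sketch): both ngram bounds now always ride along in the
    # draft worker kwargs; when ngram lookup is disabled they are presumably 0,
    # mirroring upstream vLLM's SpeculativeConfig.
    draft_worker_kwargs = dict(base_worker_kwargs)  # assumed shared worker kwargs
    draft_worker_kwargs.update(
        model_config=self.speculative_config.draft_model_config,
        parallel_config=self.speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=self.speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=self.speculative_config.ngram_prompt_lookup_min,
    )
    # Worker side: the unconditional pops added in the second hunk consume them.
    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker,  # assumed name for the target (scoring) worker
        draft_worker_kwargs,
    )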

aphrodite/spec_decode/spec_decode_worker.py (+8, -7)

@@ -2,6 +2,7 @@ from functools import cached_property
 from typing import List, Optional, Tuple
 
 import torch
+from loguru import logger
 
 from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                        SequenceGroupMetadata)
@@ -52,13 +53,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_worker_kwargs,
     ) -> "SpecDecodeWorker":
 
-        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
-            ngram_prompt_lookup_max = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
-            ngram_prompt_lookup_min = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-        else:
-            ngram_prompt_lookup_max = 0
+        ngram_prompt_lookup_max = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+        ngram_prompt_lookup_min = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
 
         if ngram_prompt_lookup_max > 0:
             proposer_worker = NGramWorker(**draft_worker_kwargs)
@@ -67,6 +65,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         else:
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
+        logger.info("Configuring SpecDecodeWorker with "
+                    f"proposer={type(proposer_worker)}")
+
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
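
For completeness, a minimal usage sketch that would exercise this path end to end. It assumes Aphrodite keeps the same `LLM` entrypoint and speculative-decoding arguments as upstream vLLM of this period; the model name, prompt, and values are illustrative, not taken from the commit:

    from aphrodite import LLM, SamplingParams  # assumed top-level exports, mirroring vLLM

    # "[ngram]" requests prompt-lookup speculation instead of a separate draft
    # model, i.e. the NGramWorker branch chosen in create_worker above.
    llm = LLM(
        model="meta-llama/Llama-2-7b-chat-hf",  # illustrative target model
        speculative_model="[ngram]",
        num_speculative_tokens=5,
        ngram_prompt_lookup_max=4,
        ngram_prompt_lookup_min=1,
        use_v2_block_manager=True,  # required for spec decode in upstream vLLM of this era; assumed here
    )

    outputs = llm.generate(
        ["The quick brown fox jumps over the lazy dog. The quick brown"],
        SamplingParams(temperature=0.0, max_tokens=32),
    )
    print(outputs[0].outputs[0].text)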