
re-add ngram speculative decoding

AlpinDale, 7 months ago
commit 723c6acb84

+ 58 - 34
aphrodite/common/config.py

@@ -756,6 +756,8 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        ngram_prompt_lookup_max: Optional[int],
+        ngram_prompt_lookup_min: Optional[int],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.
 
@@ -819,40 +821,55 @@ class SpeculativeConfig:
         draft_code_revision = None
         draft_quantization = None
 
-        draft_model_config = ModelConfig(
-            model=speculative_model,
-            tokenizer=target_model_config.tokenizer,
-            tokenizer_mode=target_model_config.tokenizer_mode,
-            trust_remote_code=target_model_config.trust_remote_code,
-            dtype=target_model_config.dtype,
-            seed=target_model_config.seed,
-            revision=draft_revision,
-            code_revision=draft_code_revision,
-            tokenizer_revision=target_model_config.tokenizer_revision,
-            max_model_len=None,
-            quantization=draft_quantization,
-            enforce_eager=target_model_config.enforce_eager,
-            max_context_len_to_capture=target_model_config.
-            max_context_len_to_capture,
-            max_logprobs=target_model_config.max_logprobs,
-        )
+        if speculative_model == "[ngram]":
+            assert (ngram_prompt_lookup_max is not None
+                    and ngram_prompt_lookup_max > 0)
+            if ngram_prompt_lookup_min is None:
+                ngram_prompt_lookup_min = 0
+            else:
+                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
 
-        draft_model_config.max_model_len = (
-            SpeculativeConfig._maybe_override_draft_max_model_len(
-                speculative_max_model_len,
-                draft_model_config.max_model_len,
-                target_model_config.max_model_len,
-            ))
-
-        draft_parallel_config = (
-            SpeculativeConfig.create_draft_parallel_config(
-                target_parallel_config))
-
-        return SpeculativeConfig(
-            draft_model_config,
-            draft_parallel_config,
-            num_speculative_tokens,
-        )
+            # TODO: currently we still need to extract vocab_size from the
+            # target model config; in the future, we may refactor it out and
+            # set the draft-related config to None here.
+            draft_model_config = target_model_config
+            draft_parallel_config = target_parallel_config
+        else:
+            ngram_prompt_lookup_max = 0
+            ngram_prompt_lookup_min = 0
+            draft_model_config = ModelConfig(
+                model=speculative_model,
+                tokenizer=target_model_config.tokenizer,
+                tokenizer_mode=target_model_config.tokenizer_mode,
+                trust_remote_code=target_model_config.trust_remote_code,
+                dtype=target_model_config.dtype,
+                seed=target_model_config.seed,
+                revision=draft_revision,
+                code_revision=draft_code_revision,
+                tokenizer_revision=target_model_config.tokenizer_revision,
+                max_model_len=None,
+                quantization=draft_quantization,
+                enforce_eager=target_model_config.enforce_eager,
+                max_context_len_to_capture=target_model_config.
+                max_context_len_to_capture,
+                max_logprobs=target_model_config.max_logprobs,
+            )
+
+            draft_model_config.max_model_len = (
+                SpeculativeConfig._maybe_override_draft_max_model_len(
+                    speculative_max_model_len,
+                    draft_model_config.max_model_len,
+                    target_model_config.max_model_len,
+                ))
+
+            draft_parallel_config = (
+                SpeculativeConfig.create_draft_parallel_config(
+                    target_parallel_config))
+
+        return SpeculativeConfig(draft_model_config, draft_parallel_config,
+                                 num_speculative_tokens,
+                                 ngram_prompt_lookup_max,
+                                 ngram_prompt_lookup_min)
 
     @staticmethod
     def _maybe_override_draft_max_model_len(
@@ -919,6 +936,8 @@ class SpeculativeConfig:
         draft_model_config: ModelConfig,
         draft_parallel_config: ParallelConfig,
         num_speculative_tokens: int,
+        ngram_prompt_lookup_max: int,
+        ngram_prompt_lookup_min: int,
     ):
         """Create a SpeculativeConfig object.
 
@@ -931,6 +950,8 @@ class SpeculativeConfig:
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
         self.num_speculative_tokens = num_speculative_tokens
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
 
         self._verify_args()
 
@@ -954,7 +975,10 @@ class SpeculativeConfig:
         return self.num_speculative_tokens
 
     def __repr__(self) -> str:
-        draft_model = self.draft_model_config.model
+        if self.ngram_prompt_lookup_max > 0:
+            draft_model = "[ngram]"
+        else:
+            draft_model = self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
         return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
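
The "[ngram]" sentinel path above enforces a window invariant before any
worker is built: the max window must be positive, and an explicit min must be
strictly smaller than the max (otherwise min defaults to 0). A standalone
sketch of that validation, with an illustrative name not present in the diff:

from typing import Optional, Tuple

def check_ngram_window(
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int]) -> Tuple[int, int]:
    # Mirrors the asserts in maybe_create_spec_config above.
    assert (ngram_prompt_lookup_max is not None
            and ngram_prompt_lookup_max > 0)
    if ngram_prompt_lookup_min is None:
        ngram_prompt_lookup_min = 0  # default: no lower bound on the window
    else:
        assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
    return ngram_prompt_lookup_min, ngram_prompt_lookup_max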
 

+ 16 - 0
aphrodite/engine/args_tools.py

@@ -78,6 +78,8 @@ class EngineArgs:
     speculative_model: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
+    ngram_prompt_lookup_max: Optional[int] = None
+    ngram_prompt_lookup_min: Optional[int] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -523,6 +525,18 @@ class EngineArgs:
             help="The maximum sequence length supported by the "
             "draft model. Sequences over this length will skip "
             "speculation.")
+        parser.add_argument(
+            "--ngram-prompt-lookup-max",
+            type=int,
+            default=EngineArgs.ngram_prompt_lookup_max,
+            help="Max size of window for ngram prompt lookup in speculative "
+            "decoding.")
+        parser.add_argument(
+            "--ngram-prompt-lookup-min",
+            type=int,
+            default=EngineArgs.ngram_prompt_lookup_min,
+            help="Min size of window for ngram prompt lookup in speculative "
+            "decoding.")
         parser.add_argument("--model-loader-extra-config",
                             type=str,
                             default=EngineArgs.model_loader_extra_config,
@@ -600,6 +614,8 @@ class EngineArgs:
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
+            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
         )
 
         scheduler_config = SchedulerConfig(
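
With these flags on EngineArgs, ngram speculation is enabled by passing the
"[ngram]" sentinel as the speculative model. A hedged usage sketch; the model
name and token counts are placeholders, and any EngineArgs fields beyond those
shown in this diff are assumptions:

from aphrodite.engine.args_tools import EngineArgs

# CLI equivalent (flag names assumed to follow the usual dashed form):
#   --speculative-model "[ngram]" --num-speculative-tokens 5 \
#   --ngram-prompt-lookup-max 4 --ngram-prompt-lookup-min 1
args = EngineArgs(
    model="facebook/opt-6.7b",     # placeholder target model
    speculative_model="[ngram]",   # sentinel: no separate draft model needed
    num_speculative_tokens=5,      # placeholder value
    ngram_prompt_lookup_max=4,
    ngram_prompt_lookup_min=1,
    use_v2_block_manager=True,     # speculative path expects the v2 manager
)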

+ 4 - 3
aphrodite/engine/output_processor/multi_step.py

@@ -1,6 +1,5 @@
 from typing import Callable, List
 
-from loguru import logger
 from transformers import PreTrainedTokenizer
 
 from aphrodite.common.sampling_params import SamplingParams
@@ -13,6 +12,7 @@ from aphrodite.engine.output_processor.interfaces import \
 from aphrodite.engine.output_processor.stop_checker import StopChecker
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
+from aphrodite.common.logger import log_once
 
 
 class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@@ -47,8 +47,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                                outputs: List[SequenceGroupOutput]) -> None:
         # TODO: Prompt logprob currently not implemented in multi step
         # workers.
-        logger.warning(
-            "Prompt logprob is not supported by multi step workers. "
+        log_once(
+            level="WARNING",
+            message="Prompt logprob is not supported by multi step workers. "
             "(e.g., speculative decode uses multi step workers).")
         pass
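
The per-call loguru warning becomes log_once, so the message is emitted only
on the first occurrence instead of once per sequence group. The implementation
of aphrodite.common.logger.log_once is not part of this diff; a minimal sketch
of such a deduplicating helper, under that assumption:

from loguru import logger

_seen_messages: set = set()

def log_once(level: str, message: str) -> None:
    # Emit each distinct message at most once per process (sketch only;
    # the real helper in aphrodite.common.logger may differ).
    if message not in _seen_messages:
        _seen_messages.add(message)
        logger.log(level, message)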
 

+ 5 - 5
aphrodite/executor/gpu_executor.py

@@ -52,7 +52,7 @@ class GPUExecutor(ExecutorBase):
                        rank: int = 0,
                        distributed_init_method: Optional[str] = None):
         wrapper = WorkerWrapperBase(
-            worker_module_name="vllm.worker.worker",
+            worker_module_name="aphrodite.task_handler.worker",
             worker_class_name="Worker",
         )
         wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
@@ -72,7 +72,6 @@ class GPUExecutor(ExecutorBase):
         """
         assert self.speculative_config is not None
 
-        from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
         from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
 
         target_worker = self._create_worker()
@@ -85,10 +84,11 @@ class GPUExecutor(ExecutorBase):
             # TODO allow draft-model specific load config.
             #load_config=self.load_config,
         )
-        draft_worker = MultiStepWorker(**draft_worker_kwargs)
 
-        spec_decode_worker = SpecDecodeWorker.from_workers(
-            proposer_worker=draft_worker, scorer_worker=target_worker)
+        spec_decode_worker = SpecDecodeWorker.create_worker(
+            scorer_worker=target_worker,
+            draft_worker_kwargs=draft_worker_kwargs,
+        )
 
         assert self.parallel_config.world_size == 1, (
             "GPUExecutor only supports single GPU.")

+ 0 - 1
aphrodite/spec_decode/batch_expansion.py

@@ -306,7 +306,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
             lora_request=None,
-            persistent_data={},
         )
 
     def _split_scoring_output(

+ 4 - 1
aphrodite/spec_decode/multi_step_worker.py

@@ -51,7 +51,10 @@ class MultiStepWorker(Worker):
         sample_len: int,
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
-        sampler output, one per model forward pass.
+        sampler output, one per model forward pass, along with an indicator of
+        whether the torch tensors in the sampler output need to be transposed
+        by the later sampler_output_to_torch logic.
+        For the multi-step worker, this indicator is always True.
         """
         self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                    blocks_to_swap_out, blocks_to_copy)
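
The indicator exists because the two proposers lay out their outputs
differently: the multi-step worker returns one SamplerOutput per forward pass
(step-major), while the ngram worker already emits sequence-major results. A
hedged illustration of what the downstream transpose amounts to; the shapes
are illustrative, not the worker's actual tensors:

import torch

# Step-major proposals: 3 forward passes over a batch of 2 sequences.
step_major = torch.tensor([[11, 21],   # step 0: one token per sequence
                           [12, 22],   # step 1
                           [13, 23]])  # step 2
# Scoring consumes sequence-major [batch, k] tensors, hence indicator=True
# here, while the ngram worker produces this layout directly (False).
seq_major = step_major.transpose(0, 1)  # shape [2, 3]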

+ 3 - 0
aphrodite/spec_decode/ngram_worker.py

@@ -78,6 +78,8 @@ class NGramWorker(LoraNotSupportedWorkerBase):
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         """NGram match algo to pick proposal candidate. Returns the list of
         sampler output, one per SequenceGroupMetadata.
+        For the ngram worker, the needed transpose is already done internally,
+        so the indicator passed to sampler_output_to_torch should be False.
         """
         self._raise_if_unsupported(
             seq_group_metadata_list,
@@ -115,6 +117,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
                     res_len = len(res)
                     # pad 0 towards output as sample_len tokens required
                     res += [0] * (sample_len - res_len)
+
                     break
             else:
                 # if no candidate found, fill with 0
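
For context, ngram prompt lookup proposes tokens by matching the tail of the
running sequence against an earlier occurrence, preferring the largest window
and padding with 0 exactly as the hunk above does. A simplified, self-contained
sketch of that matching loop; the function name is illustrative, not the
worker's actual code:

from typing import List

def ngram_propose(token_ids: List[int], min_n: int, max_n: int,
                  sample_len: int) -> List[int]:
    for n in range(max_n, min_n, -1):       # prefer the largest window
        if n >= len(token_ids):
            continue
        pattern = token_ids[-n:]
        # Scan backwards for the most recent earlier occurrence of the tail.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == pattern:
                res = token_ids[start + n:start + n + sample_len]
                # pad 0 towards output as sample_len tokens required
                return res + [0] * (sample_len - len(res))
    return [0] * sample_len                 # no candidate found, fill with 0

assert ngram_propose([1, 2, 3, 1, 2], 0, 2, sample_len=2) == [3, 1]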

+ 14 - 50
aphrodite/spec_decode/spec_decode_worker.py

@@ -2,9 +2,7 @@ from functools import cached_property
 from typing import Dict, List, Optional, Tuple
 
 import torch
-from loguru import logger
 
-from aphrodite.common.config import SchedulerConfig
 from aphrodite.common.sequence import (Logprob, SamplerOutput,
                                        SequenceGroupMetadata,
                                        SequenceGroupOutput, SequenceOutput)
@@ -46,51 +44,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         correctness tests pass.
     """
 
-    @classmethod
-    def from_workers(cls, proposer_worker: MultiStepWorker,
-                     scorer_worker: WorkerBase) -> "SpecDecodeWorker":
-        return SpecDecodeWorker(
-            proposer_worker,
-            scorer_worker,
-            # TODO: disable strict mode for speedup.
-            rejection_sampler=RejectionSampler(strict_mode=True),
-        )
-
     @classmethod
     def create_worker(
         cls,
         scorer_worker: WorkerBase,
-        speculative_config: SchedulerConfig,
+        draft_worker_kwargs,
     ) -> "SpecDecodeWorker":
 
-        if speculative_config.ngram_prompt_lookup_max > 0:
-            proposer_worker = NGramWorker(
-                model_config=speculative_config.draft_model_config,
-                parallel_config=speculative_config.draft_parallel_config,
-                scheduler_config=scorer_worker.scheduler_config,
-                device_config=scorer_worker.device_config,
-                cache_config=scorer_worker.cache_config,
-                local_rank=0,
-                rank=0,
-                distributed_init_method=scorer_worker.distributed_init_method,
-            )
-            proposer_worker.set_ngram_window_size(
-                speculative_config.ngram_prompt_lookup_min,
-                speculative_config.ngram_prompt_lookup_max)
+        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
+            ngram_prompt_lookup_max = (
+                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+            ngram_prompt_lookup_min = (
+                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
         else:
-            proposer_worker = MultiStepWorker(
-                model_config=speculative_config.draft_model_config,
-                parallel_config=speculative_config.draft_parallel_config,
-                scheduler_config=scorer_worker.scheduler_config,
-                device_config=scorer_worker.device_config,
-                cache_config=scorer_worker.cache_config,
-                local_rank=0,
-                rank=0,
-                distributed_init_method=scorer_worker.distributed_init_method,
-                lora_config=scorer_worker.lora_config,
-                vision_language_config=scorer_worker.vision_language_config,
-                is_driver_worker=True,
-            )
+            ngram_prompt_lookup_max = 0
+
+        if ngram_prompt_lookup_max > 0:
+            proposer_worker = NGramWorker(**draft_worker_kwargs)
+            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
+                                                  ngram_prompt_lookup_max)
+        else:
+            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
@@ -223,9 +197,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             "speculative decoding "
             "requires non-None seq_group_metadata_list")
 
-        logger.debug(
-            f"spec_decode_worker.execute_model {num_lookahead_slots=}")
-
         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.
         if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
@@ -256,7 +227,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposer and scorer model so that the KV cache is consistent between the
         two.
         """
-        logger.debug("run proposer worker no spec")
 
         self.proposer_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
@@ -265,7 +235,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             blocks_to_copy=blocks_to_copy,
         )
 
-        logger.debug("run target worker no spec")
         sampler_output = self.scorer_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
@@ -299,7 +268,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         sequence.
         """
 
-        logger.debug("get spec proposals")
         # Generate proposals using draft worker.
         assert blocks_to_swap_in is not None
         assert blocks_to_swap_out is not None
@@ -308,7 +276,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
             blocks_to_copy, k)
 
-        logger.debug("score proposals")
         proposal_scores = self.scorer.score_proposals(
             seq_group_metadata_list,
             blocks_to_swap_in,
@@ -318,11 +285,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposals,
         )
 
-        logger.debug("verify proposals")
         accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
                                                  proposal_scores, proposals, k)
 
-        logger.debug("create output list")
         return self._create_output_sampler_list(seq_group_metadata_list,
                                                 accepted_token_ids, k)
 
@@ -420,7 +385,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                                 output_token=token_id,
                                 # TODO Add verifier logprobs.
                                 logprobs={token_id: Logprob(0.0)},
-                                persistent_data={},
                             )
                         ],
                         prompt_logprobs=None,
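
Putting it together: GPUExecutor now passes raw constructor kwargs and lets
create_worker pick the proposer. A distilled, runnable toy version of that
dispatch decision (it returns a label instead of constructing real workers):

from typing import Any, Dict

def choose_proposer(draft_worker_kwargs: Dict[str, Any]) -> str:
    # Same branch as create_worker above: pop the ngram keys, then decide.
    ngram_max = draft_worker_kwargs.pop("ngram_prompt_lookup_max", 0)
    ngram_min = draft_worker_kwargs.pop("ngram_prompt_lookup_min", 0)
    if ngram_max > 0:
        return f"NGramWorker(window=[{ngram_min}, {ngram_max}])"
    return "MultiStepWorker"

assert choose_proposer({"ngram_prompt_lookup_max": 4,
                        "ngram_prompt_lookup_min": 1}) == \
    "NGramWorker(window=[1, 4])"
assert choose_proposer({}) == "MultiStepWorker"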