|
@@ -2,6 +2,7 @@ from functools import cached_property
|
|
|
from typing import List, Optional, Tuple
|
|
|
|
|
|
import torch
|
|
|
+from loguru import logger
|
|
|
|
|
|
from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
|
|
|
SequenceGroupMetadata)
|
|
@@ -52,13 +53,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|
|
draft_worker_kwargs,
|
|
|
) -> "SpecDecodeWorker":
|
|
|
|
|
|
- if "ngram_prompt_lookup_max" in draft_worker_kwargs:
|
|
|
- ngram_prompt_lookup_max = (
|
|
|
- draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
|
|
- ngram_prompt_lookup_min = (
|
|
|
- draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
|
|
- else:
|
|
|
- ngram_prompt_lookup_max = 0
|
|
|
+ ngram_prompt_lookup_max = (
|
|
|
+ draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
|
|
+ ngram_prompt_lookup_min = (
|
|
|
+ draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
|
|
|
|
|
if ngram_prompt_lookup_max > 0:
|
|
|
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
|
@@ -67,6 +65,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|
|
else:
|
|
|
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
|
|
|
|
|
+ logger.info("Configuring SpecDecodeWorker with "
|
|
|
+ f"proposer={type(proposer_worker)}")
|
|
|
+
|
|
|
return SpecDecodeWorker(
|
|
|
proposer_worker,
|
|
|
scorer_worker,
|