瀏覽代碼

feat: MLPSpeculator with tensor parallel

Will have to set --speculative-draft-tensor-parallel to 1
AlpinDale 7 月之前
父節點
當前提交
dd378ea063
共有 3 個文件被更改,包括 13 次插入和 16 次刪除
  1. 0 6
      aphrodite/common/config.py
  2. 4 3
      aphrodite/multimodal/utils.py
  3. 9 7
      aphrodite/spec_decode/spec_decode_worker.py

+ 0 - 6
aphrodite/common/config.py

@@ -1038,12 +1038,6 @@ class SpeculativeConfig:
             )
 
             draft_hf_config = draft_model_config.hf_config
-            if (draft_hf_config.model_type == "mlp_speculator"
-                    and target_parallel_config.world_size != 1):
-                # MLPSpeculator TP support will be added very soon
-                raise ValueError(
-                    "Speculative decoding with mlp_speculator models does not "
-                    "yet support distributed inferencing (TP > 1).")
 
             if (num_speculative_tokens is not None
                     and hasattr(draft_hf_config, "num_lookahead_tokens")):

+ 4 - 3
aphrodite/multimodal/utils.py

@@ -10,9 +10,9 @@ from PIL import Image
 from aphrodite.common.config import ModelConfig
 from aphrodite.multimodal.base import MultiModalDataDict
 
+APHRODITE_IMAGE_FETCH_TIMEOUT = int(
+    os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
 
-APHRODITE_IMAGE_FETCH_TIMEOUT = int(os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT",
-                                              10))
 
 class ImageFetchAiohttp:
     aiohttp_client: Optional[aiohttp.ClientSession] = None
@@ -20,7 +20,8 @@ class ImageFetchAiohttp:
     @classmethod
     def get_aiohttp_client(cls) -> aiohttp.ClientSession:
         if cls.aiohttp_client is None:
-            timeout = aiohttp.ClientTimeout(total=APHRODITE_IMAGE_FETCH_TIMEOUT)
+            timeout = aiohttp.ClientTimeout(
+                total=APHRODITE_IMAGE_FETCH_TIMEOUT)
             connector = aiohttp.TCPConnector()
             cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
                                                        connector=connector)

+ 9 - 7
aphrodite/spec_decode/spec_decode_worker.py

@@ -115,19 +115,21 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
-        elif draft_worker_kwargs[
-                "model_config"].hf_config.model_type == "mlp_speculator":
-            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
-            disable_bonus_tokens = False
         else:
             draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                 'parallel_config']
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
-            if draft_tp == 1:
-                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
-            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            if draft_worker_kwargs[
+                    "model_config"].hf_config.model_type == "mlp_speculator":
+                disable_bonus_tokens = False
+                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+            else:
+                if draft_tp == 1:
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)