Browse Source

chore: quant config for speculative draft models (#719)

AlpinDale 6 tháng trước cách đây
mục cha
commit
28b6397188
2 tập tin đã thay đổi với 23 bổ sung4 xóa
  1. 8 4
      aphrodite/common/config.py
  2. 15 0
      aphrodite/engine/args_tools.py

+ 8 - 4
aphrodite/common/config.py

@@ -1037,6 +1037,7 @@ class SpeculativeConfig:
         target_parallel_config: ParallelConfig,
         target_dtype: str,
         speculative_model: Optional[str],
+        speculative_model_quantization: Optional[str],
         speculative_draft_tensor_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_max_model_len: Optional[int],
@@ -1068,6 +1069,9 @@ class SpeculativeConfig:
             num_speculative_tokens (Optional[int]): The number of speculative
                 tokens, if provided. Will default to the number in the draft
                 model config if present, otherwise is required.
+            speculative_model_quantization (Optional[str]): Quantization method
+                that was used to quantize the speculative model weights. If
+                None, we assume the model weights are not quantized.
             speculative_draft_tensor_parallel_size (Optional[int]): The degree
                 of the tensor parallelism for the draft model.
             speculative_max_model_len (Optional[int]): The maximum model len of
@@ -1131,11 +1135,11 @@ class SpeculativeConfig:
                 "Speculative decoding requires usage of the V2 "
                 "block manager. Enable it with --use-v2-block-manager.")
 
-        # TODO: The user should be able to specify revision/quantization/max
-        # model len for the draft model. It is not currently supported.
+        # TODO: The user should be able to specify revision/max model len
+        # for the draft model. It is not currently supported.
         draft_revision = None
         draft_code_revision = None
-        draft_quantization = None
+        draft_quantization = speculative_model_quantization
 
         if speculative_model == "[ngram]":
             if ngram_prompt_lookup_min is None:
@@ -1283,7 +1287,7 @@ class SpeculativeConfig:
         elif speculative_draft_tensor_parallel_size != 1:
             # TODO: allow tp values larger than 1
             raise ValueError(
-                f"{speculative_draft_tensor_parallel_size=} cannot be"
+                f"{speculative_draft_tensor_parallel_size=} cannot be "
                 f"other value than 1")
         draft_parallel_config = ParallelConfig(
             pipeline_parallel_size=target_parallel_config.

+ 15 - 0
aphrodite/engine/args_tools.py

@@ -115,6 +115,7 @@ class EngineArgs:
     # Speculative Decoding Options
     num_lookahead_slots: int = 0
     speculative_model: Optional[str] = None
+    speculative_model_quantization: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
     ngram_prompt_lookup_max: Optional[int] = None
@@ -639,6 +640,18 @@ class EngineArgs:
             default=EngineArgs.speculative_model,
             help="Category: Speculative Decoding Options\n"
             "The name of the draft model to be used in speculative decoding.")
+        # Quantization settings for speculative model.
+        parser.add_argument(
+            '--speculative-model-quantization',
+            type=str,
+            choices=[*QUANTIZATION_METHODS, None],
+            default=EngineArgs.speculative_model_quantization,
+            help='Method used to quantize the weights of speculative model.'
+            'If None, we first check the `quantization_config` '
+            'attribute in the model config file. If that is '
+            'None, we assume the model weights are not '
+            'quantized and use `dtype` to determine the data '
+            'type of the weights.')
         parser.add_argument("--num-speculative-tokens",
                             type=int,
                             default=EngineArgs.num_speculative_tokens,
@@ -956,6 +969,8 @@ class EngineArgs:
             target_parallel_config=parallel_config,
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
+            speculative_model_quantization = \
+                self.speculative_model_quantization,
             speculative_draft_tensor_parallel_size=self.
             speculative_draft_tensor_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,