Browse Source

chore: quant config for speculative draft models (#719)

AlpinDale 6 months ago
parent
commit
28b6397188
2 changed files with 23 additions and 4 deletions
  1. 8 4
      aphrodite/common/config.py
  2. 15 0
      aphrodite/engine/args_tools.py

+ 8 - 4
aphrodite/common/config.py

@@ -1037,6 +1037,7 @@ class SpeculativeConfig:
         target_parallel_config: ParallelConfig,
         target_parallel_config: ParallelConfig,
         target_dtype: str,
         target_dtype: str,
         speculative_model: Optional[str],
         speculative_model: Optional[str],
+        speculative_model_quantization: Optional[str],
         speculative_draft_tensor_parallel_size: Optional[int],
         speculative_draft_tensor_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_max_model_len: Optional[int],
         speculative_max_model_len: Optional[int],
@@ -1068,6 +1069,9 @@ class SpeculativeConfig:
             num_speculative_tokens (Optional[int]): The number of speculative
             num_speculative_tokens (Optional[int]): The number of speculative
                 tokens, if provided. Will default to the number in the draft
                 tokens, if provided. Will default to the number in the draft
                 model config if present, otherwise is required.
                 model config if present, otherwise is required.
+            speculative_model_quantization (Optional[str]): Quantization method
+                that was used to quantize the speculative model weights. If
+                None, we assume the model weights are not quantized.
             speculative_draft_tensor_parallel_size (Optional[int]): The degree
             speculative_draft_tensor_parallel_size (Optional[int]): The degree
                 of the tensor parallelism for the draft model.
                 of the tensor parallelism for the draft model.
             speculative_max_model_len (Optional[int]): The maximum model len of
             speculative_max_model_len (Optional[int]): The maximum model len of
@@ -1131,11 +1135,11 @@ class SpeculativeConfig:
                 "Speculative decoding requires usage of the V2 "
                 "Speculative decoding requires usage of the V2 "
                 "block manager. Enable it with --use-v2-block-manager.")
                 "block manager. Enable it with --use-v2-block-manager.")
 
 
-        # TODO: The user should be able to specify revision/quantization/max
-        # model len for the draft model. It is not currently supported.
+        # TODO: The user should be able to specify revision/max model len
+        # for the draft model. It is not currently supported.
         draft_revision = None
         draft_revision = None
         draft_code_revision = None
         draft_code_revision = None
-        draft_quantization = None
+        draft_quantization = speculative_model_quantization
 
 
         if speculative_model == "[ngram]":
         if speculative_model == "[ngram]":
             if ngram_prompt_lookup_min is None:
             if ngram_prompt_lookup_min is None:
@@ -1283,7 +1287,7 @@ class SpeculativeConfig:
         elif speculative_draft_tensor_parallel_size != 1:
         elif speculative_draft_tensor_parallel_size != 1:
             # TODO: allow tp values larger than 1
             # TODO: allow tp values larger than 1
             raise ValueError(
             raise ValueError(
-                f"{speculative_draft_tensor_parallel_size=} cannot be"
+                f"{speculative_draft_tensor_parallel_size=} cannot be "
                 f"other value than 1")
                 f"other value than 1")
         draft_parallel_config = ParallelConfig(
         draft_parallel_config = ParallelConfig(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size=target_parallel_config.

+ 15 - 0
aphrodite/engine/args_tools.py

@@ -115,6 +115,7 @@ class EngineArgs:
     # Speculative Decoding Options
     # Speculative Decoding Options
     num_lookahead_slots: int = 0
     num_lookahead_slots: int = 0
     speculative_model: Optional[str] = None
     speculative_model: Optional[str] = None
+    speculative_model_quantization: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
     ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_max: Optional[int] = None
@@ -639,6 +640,18 @@ class EngineArgs:
             default=EngineArgs.speculative_model,
             default=EngineArgs.speculative_model,
             help="Category: Speculative Decoding Options\n"
             help="Category: Speculative Decoding Options\n"
             "The name of the draft model to be used in speculative decoding.")
             "The name of the draft model to be used in speculative decoding.")
+        # Quantization settings for speculative model.
+        parser.add_argument(
+            '--speculative-model-quantization',
+            type=str,
+            choices=[*QUANTIZATION_METHODS, None],
+            default=EngineArgs.speculative_model_quantization,
+            help='Method used to quantize the weights of speculative model.'
+            'If None, we first check the `quantization_config` '
+            'attribute in the model config file. If that is '
+            'None, we assume the model weights are not '
+            'quantized and use `dtype` to determine the data '
+            'type of the weights.')
         parser.add_argument("--num-speculative-tokens",
         parser.add_argument("--num-speculative-tokens",
                             type=int,
                             type=int,
                             default=EngineArgs.num_speculative_tokens,
                             default=EngineArgs.num_speculative_tokens,
@@ -956,6 +969,8 @@ class EngineArgs:
             target_parallel_config=parallel_config,
             target_parallel_config=parallel_config,
             target_dtype=self.dtype,
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
             speculative_model=self.speculative_model,
+            speculative_model_quantization = \
+                self.speculative_model_quantization,
             speculative_draft_tensor_parallel_size=self.
             speculative_draft_tensor_parallel_size=self.
             speculative_draft_tensor_parallel_size,
             speculative_draft_tensor_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,
             num_speculative_tokens=self.num_speculative_tokens,