|
@@ -1037,6 +1037,7 @@ class SpeculativeConfig:
|
|
|
target_parallel_config: ParallelConfig,
|
|
|
target_dtype: str,
|
|
|
speculative_model: Optional[str],
|
|
|
+ speculative_model_quantization: Optional[str],
|
|
|
speculative_draft_tensor_parallel_size: Optional[int],
|
|
|
num_speculative_tokens: Optional[int],
|
|
|
speculative_max_model_len: Optional[int],
|
|
@@ -1068,6 +1069,9 @@ class SpeculativeConfig:
|
|
|
num_speculative_tokens (Optional[int]): The number of speculative
|
|
|
tokens, if provided. Will default to the number in the draft
|
|
|
model config if present, otherwise is required.
|
|
|
+ speculative_model_quantization (Optional[str]): Quantization method
|
|
|
+ that was used to quantize the speculative model weights. If
|
|
|
+ None, we assume the model weights are not quantized.
|
|
|
speculative_draft_tensor_parallel_size (Optional[int]): The degree
|
|
|
of the tensor parallelism for the draft model.
|
|
|
speculative_max_model_len (Optional[int]): The maximum model len of
|
|
@@ -1131,11 +1135,11 @@ class SpeculativeConfig:
|
|
|
"Speculative decoding requires usage of the V2 "
|
|
|
"block manager. Enable it with --use-v2-block-manager.")
|
|
|
|
|
|
- # TODO: The user should be able to specify revision/quantization/max
|
|
|
- # model len for the draft model. It is not currently supported.
|
|
|
+ # TODO: The user should be able to specify revision/max model len
|
|
|
+ # for the draft model. It is not currently supported.
|
|
|
draft_revision = None
|
|
|
draft_code_revision = None
|
|
|
- draft_quantization = None
|
|
|
+ draft_quantization = speculative_model_quantization
|
|
|
|
|
|
if speculative_model == "[ngram]":
|
|
|
if ngram_prompt_lookup_min is None:
|
|
@@ -1283,7 +1287,7 @@ class SpeculativeConfig:
|
|
|
elif speculative_draft_tensor_parallel_size != 1:
|
|
|
# TODO: allow tp values larger than 1
|
|
|
raise ValueError(
|
|
|
- f"{speculative_draft_tensor_parallel_size=} cannot be"
|
|
|
+ f"{speculative_draft_tensor_parallel_size=} cannot be "
|
|
|
f"other value than 1")
|
|
|
draft_parallel_config = ParallelConfig(
|
|
|
pipeline_parallel_size=target_parallel_config.
|