@@ -756,6 +756,8 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        ngram_prompt_lookup_max: Optional[int],
+        ngram_prompt_lookup_min: Optional[int],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.
 
@@ -819,40 +821,55 @@ class SpeculativeConfig:
         draft_code_revision = None
         draft_quantization = None
 
-        draft_model_config = ModelConfig(
-            model=speculative_model,
-            tokenizer=target_model_config.tokenizer,
-            tokenizer_mode=target_model_config.tokenizer_mode,
-            trust_remote_code=target_model_config.trust_remote_code,
-            dtype=target_model_config.dtype,
-            seed=target_model_config.seed,
-            revision=draft_revision,
-            code_revision=draft_code_revision,
-            tokenizer_revision=target_model_config.tokenizer_revision,
-            max_model_len=None,
-            quantization=draft_quantization,
-            enforce_eager=target_model_config.enforce_eager,
-            max_context_len_to_capture=target_model_config.
-            max_context_len_to_capture,
-            max_logprobs=target_model_config.max_logprobs,
-        )
+        if speculative_model == "[ngram]":
+            assert (ngram_prompt_lookup_max is not None
+                    and ngram_prompt_lookup_max > 0)
+            if ngram_prompt_lookup_min is None:
+                ngram_prompt_lookup_min = 0
+            else:
+                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
 
-        draft_model_config.max_model_len = (
-            SpeculativeConfig._maybe_override_draft_max_model_len(
-                speculative_max_model_len,
-                draft_model_config.max_model_len,
-                target_model_config.max_model_len,
-            ))
-
-        draft_parallel_config = (
-            SpeculativeConfig.create_draft_parallel_config(
-                target_parallel_config))
-
-        return SpeculativeConfig(
-            draft_model_config,
-            draft_parallel_config,
-            num_speculative_tokens,
-        )
+            # TODO: currently we still need to extract vocab_size from the
+            # target model config; in the future, we may refactor it out and
+            # set the draft-related config to None here.
+            draft_model_config = target_model_config
+            draft_parallel_config = target_parallel_config
+        else:
+            ngram_prompt_lookup_max = 0
+            ngram_prompt_lookup_min = 0
+            draft_model_config = ModelConfig(
+                model=speculative_model,
+                tokenizer=target_model_config.tokenizer,
+                tokenizer_mode=target_model_config.tokenizer_mode,
+                trust_remote_code=target_model_config.trust_remote_code,
+                dtype=target_model_config.dtype,
+                seed=target_model_config.seed,
+                revision=draft_revision,
+                code_revision=draft_code_revision,
+                tokenizer_revision=target_model_config.tokenizer_revision,
+                max_model_len=None,
+                quantization=draft_quantization,
+                enforce_eager=target_model_config.enforce_eager,
+                max_context_len_to_capture=target_model_config.
+                max_context_len_to_capture,
+                max_logprobs=target_model_config.max_logprobs,
+            )
+
+            draft_model_config.max_model_len = (
+                SpeculativeConfig._maybe_override_draft_max_model_len(
+                    speculative_max_model_len,
+                    draft_model_config.max_model_len,
+                    target_model_config.max_model_len,
+                ))
+
+            draft_parallel_config = (
+                SpeculativeConfig.create_draft_parallel_config(
+                    target_parallel_config))
+
+        return SpeculativeConfig(draft_model_config, draft_parallel_config,
+                                 num_speculative_tokens,
+                                 ngram_prompt_lookup_max,
+                                 ngram_prompt_lookup_min)
 
     @staticmethod
    def _maybe_override_draft_max_model_len(
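
The `[ngram]` sentinel selects prompt-lookup decoding: rather than loading a separate draft model, the proposer matches the most recent `n` tokens of the sequence (with `n` swept from `ngram_prompt_lookup_max` down toward `ngram_prompt_lookup_min`) against earlier context and proposes the tokens that followed the match. A minimal sketch of that matching idea follows; the function `propose_ngram` is purely illustrative and is not the proposer implementation this config feeds into:

```python
# Illustrative sketch only -- not the vLLM proposer code. It shows the
# prompt-lookup idea behind the "[ngram]" draft model.
from typing import List, Optional


def propose_ngram(token_ids: List[int], num_speculative_tokens: int,
                  lookup_max: int, lookup_min: int) -> Optional[List[int]]:
    # Try the longest n-gram window first, mirroring the sweep implied by
    # ngram_prompt_lookup_max / ngram_prompt_lookup_min.
    for n in range(lookup_max, max(lookup_min, 1) - 1, -1):
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        # Search backwards so the most recent earlier match wins.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                follow = token_ids[start + n:start + n +
                                   num_speculative_tokens]
                if follow:
                    return follow
    return None  # No match: the caller falls back to plain decoding.


print(propose_ngram([1, 2, 3, 4, 2, 3], 2, 3, 1))  # -> [4, 2]
```
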
@@ -919,6 +936,8 @@ class SpeculativeConfig:
         draft_model_config: ModelConfig,
         draft_parallel_config: ParallelConfig,
         num_speculative_tokens: int,
+        ngram_prompt_lookup_max: int,
+        ngram_prompt_lookup_min: int,
     ):
         """Create a SpeculativeConfig object.
 
@@ -931,6 +950,8 @@ class SpeculativeConfig:
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
         self.num_speculative_tokens = num_speculative_tokens
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
 
         self._verify_args()
 
@@ -954,7 +975,10 @@ class SpeculativeConfig:
         return self.num_speculative_tokens
 
     def __repr__(self) -> str:
-        draft_model = self.draft_model_config.model
+        if self.ngram_prompt_lookup_max > 0:
+            draft_model = "[ngram]"
+        else:
+            draft_model = self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
         return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
 
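For context, once the new flags are plumbed through the engine arguments (handled elsewhere in this change, not in these hunks), the ngram path would be exercised roughly as below. The exact `LLM` keyword names are assumed from the rest of the PR:

```python
from vllm import LLM

# Assumed invocation: speculative_model="[ngram]" takes the new branch in
# maybe_create_spec_config, so no separate draft model is loaded.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,  # must be > 0 per the new assert
    use_v2_block_manager=True,
)
print(llm.generate("The future of AI is"))
```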