@@ -77,6 +77,10 @@ class ModelConfig:
         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode
+        disable_sliding_window: Whether to disable sliding window. If True,
+            we will disable the sliding window functionality of the model.
+            If the model does not support sliding window, this argument is
+            ignored.
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer.
     """
@@ -104,6 +108,7 @@ class ModelConfig:
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
     ) -> None:
         self.model = model
@@ -129,14 +134,18 @@ class ModelConfig:
         self.max_seq_len_to_capture = (max_seq_len_to_capture
                                        or max_context_len_to_capture)
         self.max_logprobs = max_logprobs
+        self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
-                                                     max_model_len)
+        self.max_model_len = _get_and_verify_max_len(
+            hf_config=self.hf_text_config,
+            max_model_len=max_model_len,
+            disable_sliding_window=self.disable_sliding_window,
+            sliding_window_len=self.get_hf_config_sliding_window())
 
         if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
                 and getattr(self.hf_config, "rope_scaling", None) is None):
@@ -308,7 +317,7 @@ class ModelConfig:
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
-    def get_sliding_window(self) -> Optional[int]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
 
@@ -320,6 +329,15 @@ class ModelConfig:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+        # If user disables sliding window, return None.
+        if self.disable_sliding_window:
+            return None
+        # Otherwise get the value from the hf config.
+        return self.get_hf_config_sliding_window()
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size
 
@@ -424,6 +442,7 @@ class CacheConfig:
         self.enable_prefix_caching = enable_prefix_caching
         self._verify_args()
         self._verify_cache_dtype()
+        self._verify_prefix_caching()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -452,6 +471,19 @@ class CacheConfig:
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
+    def _verify_prefix_caching(self) -> None:
+        if not self.enable_prefix_caching:
+            return
+
+        if self.sliding_window is not None:
+            raise NotImplementedError(
+                "Prefix caching is not supported with sliding window. "
+                "Run with --disable-sliding-window to use prefix caching.")
+        if self.cache_dtype == "fp8":
+            raise NotImplementedError(
+                "Prefix caching is not supported for fp8 cache_dtype. "
+                "Run with --kv-cache-dtype auto to use prefix caching.")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
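
The new _verify_prefix_caching check interacts with disable_sliding_window: assuming (this diff does not show it) that CacheConfig.sliding_window is populated from ModelConfig.get_sliding_window(), a sliding-window model can only pass the prefix-caching check once the window is disabled, because the effective window then becomes None. A minimal standalone sketch of that interplay; the helper names effective_sliding_window and verify_prefix_caching are hypothetical, not vLLM APIs:

    from typing import Optional

    def effective_sliding_window(hf_sliding_window: Optional[int],
                                 disable_sliding_window: bool) -> Optional[int]:
        # Mirrors the new ModelConfig.get_sliding_window(): the window from
        # the HF config is ignored when the user disables it.
        return None if disable_sliding_window else hf_sliding_window

    def verify_prefix_caching(enable_prefix_caching: bool,
                              sliding_window: Optional[int],
                              cache_dtype: str) -> None:
        # Mirrors CacheConfig._verify_prefix_caching() above.
        if not enable_prefix_caching:
            return
        if sliding_window is not None:
            raise NotImplementedError(
                "Prefix caching is not supported with sliding window.")
        if cache_dtype == "fp8":
            raise NotImplementedError(
                "Prefix caching is not supported for fp8 cache_dtype.")

    # A model with a 4096-token window passes the check only when the window
    # is disabled, since the effective window is then None.
    verify_prefix_caching(True, effective_sliding_window(4096, True), "auto")
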
@@ -1203,6 +1235,8 @@ def _get_and_verify_dtype(
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     max_model_len: Optional[int],
+    disable_sliding_window: bool,
+    sliding_window_len: Optional[int],
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1224,6 +1258,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
+    # Choose the smallest "max_length" from the possible keys.
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
@@ -1231,6 +1266,16 @@ def _get_and_verify_max_len(
             max_len_key = key if max_len < derived_max_model_len \
                 else max_len_key
             derived_max_model_len = min(derived_max_model_len, max_len)
+
+    # If sliding window is manually disabled, max_length should be less
+    # than the sliding window length in the model config.
+    if disable_sliding_window and sliding_window_len is not None:
+        max_len_key = "sliding_window" \
+            if sliding_window_len < derived_max_model_len else max_len_key
+        derived_max_model_len = min(derived_max_model_len, sliding_window_len)
+
+    # If none of the keys were found in the config, use a default and
+    # log a warning.
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
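
The capping added above can be read as: once the sliding window is manually disabled, attention spans the full context, so the usable context length is capped at the window length recorded in the HF config. A standalone sketch of the rule; the helper derive_max_len and the Mistral-like numbers 32768/4096 are illustrative, not vLLM code:

    from typing import Optional

    def derive_max_len(max_position_embeddings: int,
                       sliding_window_len: Optional[int],
                       disable_sliding_window: bool) -> int:
        # With the window disabled, the window length becomes an upper bound
        # on the usable model length.
        derived = max_position_embeddings
        if disable_sliding_window and sliding_window_len is not None:
            derived = min(derived, sliding_window_len)
        return derived

    # Illustrative values: 32768 positions, 4096-token sliding window.
    assert derive_max_len(32768, 4096, disable_sliding_window=False) == 32768
    assert derive_max_len(32768, 4096, disable_sliding_window=True) == 4096
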
@@ -1248,6 +1293,13 @@ def _get_and_verify_max_len(
     if rope_scaling is not None:
         rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
         if rope_type not in {"su", "longrope", "llama3"}:
+            if disable_sliding_window:
+                # TODO: Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
             assert "factor" in rope_scaling
             scaling_factor = rope_scaling["factor"]
             if rope_type == "yarn":