add guards for prefix caching, fp8, chunked prefill, etc

AlpinDale 7 months ago
parent
commit
ac79d115b3

+ 2 - 1
aphrodite/attention/layer.py

@@ -29,7 +29,6 @@ class Attention(nn.Module):
         scale: float,
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
@@ -38,9 +37,11 @@ class Attention(nn.Module):
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
+            sliding_window = cache_config.sliding_window
         else:
             kv_cache_dtype = "auto"
             block_size = 16
+            sliding_window = None
         if num_kv_heads is None:
             num_kv_heads = num_heads
 

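Note on this change: `Attention` no longer takes `sliding_window` as a constructor argument; the value is now resolved from `CacheConfig` alongside the cache dtype and block size, which is why the model files below drop their per-model plumbing. A minimal, self-contained sketch of the new resolution order, assuming a simplified stand-in for `CacheConfig` (the `CacheConfigSketch` class and `resolve_cache_settings` helper are illustrative only, not part of the codebase):

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class CacheConfigSketch:
    # Stand-in for aphrodite.common.config.CacheConfig (simplified).
    cache_dtype: str = "auto"
    block_size: int = 16
    sliding_window: Optional[int] = None

def resolve_cache_settings(
        cache_config: Optional[CacheConfigSketch]
) -> Tuple[str, int, Optional[int]]:
    # Mirrors the branch added in Attention.__init__: everything cache-related
    # comes from the CacheConfig when one is given, otherwise the old defaults.
    if cache_config is not None:
        return (cache_config.cache_dtype, cache_config.block_size,
                cache_config.sliding_window)
    return "auto", 16, None

# A model layer now only hands over its CacheConfig; the window is no longer
# a separate constructor argument.
print(resolve_cache_settings(CacheConfigSketch(sliding_window=4096)))
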
+ 55 - 3
aphrodite/common/config.py

@@ -77,6 +77,10 @@ class ModelConfig:
         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode
+        disable_sliding_window: Whether to disable the model's sliding
+            window attention. If True, the maximum model length is capped
+            to the sliding window size. If the model does not support
+            sliding window, this argument is ignored.
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer.
     """
@@ -104,6 +108,7 @@ class ModelConfig:
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
     ) -> None:
         self.model = model
@@ -129,14 +134,18 @@ class ModelConfig:
         self.max_seq_len_to_capture = (max_seq_len_to_capture
                                        or max_context_len_to_capture)
         self.max_logprobs = max_logprobs
+        self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
-                                                     max_model_len)
+        self.max_model_len = _get_and_verify_max_len(
+            hf_config=self.hf_text_config,
+            max_model_len=max_model_len,
+            disable_sliding_window=self.disable_sliding_window,
+            sliding_window_len=self.get_hf_config_sliding_window())
 
         if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
                 and getattr(self.hf_config, "rope_scaling", None) is None):
@@ -308,7 +317,7 @@ class ModelConfig:
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
-    def get_sliding_window(self) -> Optional[int]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
 
@@ -320,6 +329,15 @@ class ModelConfig:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)
 
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+        # If user disables sliding window, return None.
+        if self.disable_sliding_window:
+            return None
+        # Otherwise get the value from the hf config.
+        return self.get_hf_config_sliding_window()
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size
 
@@ -424,6 +442,7 @@ class CacheConfig:
         self.enable_prefix_caching = enable_prefix_caching
         self._verify_args()
         self._verify_cache_dtype()
+        self._verify_prefix_caching()
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -452,6 +471,19 @@ class CacheConfig:
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
+    def _verify_prefix_caching(self) -> None:
+        if not self.enable_prefix_caching:
+            return
+
+        if self.sliding_window is not None:
+            raise NotImplementedError(
+                "Prefix caching is not supported with sliding window. "
+                "Run with --disable-sliding-window to use prefix caching.")
+        if self.cache_dtype == "fp8":
+            raise NotImplementedError(
+                "Prefix caching is not supported for fp8 cache_dtype. "
+                "Run with --kv-cache-dtype auto to use prefix caching.")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
@@ -1203,6 +1235,8 @@ def _get_and_verify_dtype(
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     max_model_len: Optional[int],
+    disable_sliding_window: bool,
+    sliding_window_len: Optional[int],
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1224,6 +1258,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
+    # Choose the smallest "max_length" from the possible keys.
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
@@ -1231,6 +1266,16 @@ def _get_and_verify_max_len(
             max_len_key = key if max_len < derived_max_model_len \
                 else max_len_key
             derived_max_model_len = min(derived_max_model_len, max_len)
+
+    # If sliding window is manually disabled, cap the max length to the
+    # sliding window length from the model config.
+    if disable_sliding_window and sliding_window_len is not None:
+        max_len_key = "sliding_window" \
+            if sliding_window_len < derived_max_model_len else max_len_key
+        derived_max_model_len = min(derived_max_model_len, sliding_window_len)
+
+    # If none of the keys were found in the config, use a default and
+    # log a warning.
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
@@ -1248,6 +1293,13 @@ def _get_and_verify_max_len(
     if rope_scaling is not None:
         rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
         if rope_type not in {"su", "longrope", "llama3"}:
+            if disable_sliding_window:
+                # TODO: Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
             assert "factor" in rope_scaling
             scaling_factor = rope_scaling["factor"]
             if rope_type == "yarn":

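The effect of `disable_sliding_window` on the derived maximum length can be illustrated with a small self-contained sketch of the capping logic in `_get_and_verify_max_len` (the `derive_max_len_sketch` helper below is hypothetical and ignores rope scaling and a user-supplied `max_model_len`):

from typing import Dict, Optional

def derive_max_len_sketch(config_lengths: Dict[str, int],
                          disable_sliding_window: bool,
                          sliding_window_len: Optional[int]) -> float:
    # Start from the smallest of the candidate length keys in the HF config.
    derived = min(config_lengths.values(), default=float("inf"))
    # Disabling the window means the model cannot attend past it, so the
    # derived max length is capped to the window size.
    if disable_sliding_window and sliding_window_len is not None:
        derived = min(derived, sliding_window_len)
    return derived

# Example: a config advertising 32768 positions with a 4096-token window.
assert derive_max_len_sketch({"max_position_embeddings": 32768},
                             disable_sliding_window=True,
                             sliding_window_len=4096) == 4096
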
+ 9 - 2
aphrodite/engine/args_tools.py

@@ -36,6 +36,7 @@ class EngineArgs:
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
     enable_prefix_caching: bool = False
+    disable_sliding_window: bool = False
     use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
@@ -293,8 +294,12 @@ class EngineArgs:
             "--enable-prefix-caching",
             "--context-shift",
             action="store_true",
-            help="Enable context shifting.",
+            help="Enable automatic prefix caching.",
         )
+        parser.add_argument("--disable-sliding-window",
+                            action="store_true",
+                            help="Disables sliding window, capping max "
+                            "model length to the sliding window size.")
         parser.add_argument("--use-v2-block-manager",
                             action="store_true",
                             help="Use the v2 block manager.")
@@ -633,6 +638,7 @@ class EngineArgs:
             self.max_context_len_to_capture,
             self.max_seq_len_to_capture,
             self.max_logprobs,
+            self.disable_sliding_window,
             self.skip_tokenizer_init,
         )
 
@@ -727,7 +733,8 @@ class EngineArgs:
         if (model_config.get_sliding_window() is not None
                 and scheduler_config.chunked_prefill_enabled):
             raise ValueError(
-                "Chunked prefill is not supported with sliding window.")
+                "Chunked prefill is not supported with sliding window. "
+                "Set --disable-sliding-window to disable sliding window.")
 
         return EngineConfig(model_config=model_config,
                             cache_config=cache_config,

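Taken together with the config checks above, these guards mean a sliding-window model can use prefix caching or chunked prefill only after the window has been disabled. A standalone sketch of the combined checks (a simplified restatement, not the actual classes; `check_features` is a hypothetical helper):

from typing import Optional

def check_features(sliding_window: Optional[int],
                   enable_prefix_caching: bool,
                   kv_cache_dtype: str,
                   chunked_prefill_enabled: bool) -> None:
    # Mirrors CacheConfig._verify_prefix_caching and the chunked-prefill
    # check in EngineArgs (simplified).
    if enable_prefix_caching and sliding_window is not None:
        raise NotImplementedError(
            "Prefix caching is not supported with sliding window.")
    if enable_prefix_caching and kv_cache_dtype == "fp8":
        raise NotImplementedError(
            "Prefix caching is not supported for fp8 cache_dtype.")
    if chunked_prefill_enabled and sliding_window is not None:
        raise ValueError(
            "Chunked prefill is not supported with sliding window.")

# With --disable-sliding-window the effective window is None, so both
# prefix caching and chunked prefill pass the checks.
check_features(sliding_window=None, enable_prefix_caching=True,
               kv_cache_dtype="auto", chunked_prefill_enabled=True)
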
+ 0 - 4
aphrodite/modeling/models/llama.py

@@ -94,7 +94,6 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
@@ -148,7 +147,6 @@ class LlamaAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window,
                               cache_config=cache_config,
                               quant_config=quant_config)
 
@@ -185,7 +183,6 @@ class LlamaDecoderLayer(nn.Module):
                 config.original_max_position_embeddings)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        sliding_window = getattr(config, "sliding_window", None)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
@@ -201,7 +198,6 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
-            sliding_window=sliding_window,
             cache_config=cache_config,
         )
         self.mlp = LlamaMLP(

+ 1 - 5
aphrodite/modeling/models/mixtral.py

@@ -251,8 +251,7 @@ class MixtralAttention(nn.Module):
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -274,7 +273,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         if isinstance(
                 quant_config,
@@ -310,7 +308,6 @@ class MixtralAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=self.sliding_window,
                               cache_config=cache_config,
                               quant_config=quant_config)
 
@@ -347,7 +344,6 @@ class MixtralDecoderLayer(nn.Module):
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(

+ 0 - 4
aphrodite/modeling/models/mixtral_quant.py

@@ -165,7 +165,6 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
@@ -189,7 +188,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -216,7 +214,6 @@ class MixtralAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=self.sliding_window,
                               cache_config=cache_config,
                               quant_config=quant_config)
 
@@ -253,7 +250,6 @@ class MixtralDecoderLayer(nn.Module):
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,

+ 13 - 11
aphrodite/modeling/models/qwen2.py

@@ -85,10 +85,8 @@ class Qwen2Attention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
-                 use_sliding_window: bool = False,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None,
                  rope_scaling: Optional[Tuple] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -111,7 +109,6 @@ class Qwen2Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window if use_sliding_window else None
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -139,7 +136,6 @@ class Qwen2Attention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=self.sliding_window,
                               cache_config=cache_config,
                               quant_config=quant_config)
 
@@ -163,7 +159,6 @@ class Qwen2DecoderLayer(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -172,18 +167,14 @@ class Qwen2DecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        use_sliding_window = (config.use_sliding_window
-                              and layer_idx < config.max_window_layers)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            use_sliding_window=use_sliding_window,
             cache_config=cache_config,
             quant_config=quant_config,
-            sliding_window=config.sliding_window,
             rope_scaling=rope_scaling)
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
@@ -243,8 +234,8 @@ class Qwen2Model(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
+            Qwen2DecoderLayer(config, cache_config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -301,6 +292,17 @@ class Qwen2ForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config
+        # TODO: see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError(
+                "Sliding window for some but all layers is not "
+                "supported. This model uses sliding window "
+                "but `max_window_layers` = "
+                f"{config.max_window_layers} is less than "
+                "`num_hidden_layers` = "
+                f"{config.num_hidden_layers}. Please open an issue"
+                " to discuss this feature.")
         super().__init__()
         self.config = config
         self.quant_config = quant_config

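The removed `use_sliding_window`/`max_window_layers` plumbing previously applied the window only to the first `max_window_layers` layers; with per-layer windows gone, such configs are now rejected up front in `Qwen2ForCausalLM`. As written, the guard fires whenever `max_window_layers` is present and a sliding window is set; the stricter condition the error message describes can be sketched as follows (`rejects_partial_window` is a hypothetical helper; attribute names follow the HF Qwen2 config):

from typing import Optional

def rejects_partial_window(num_hidden_layers: int,
                           max_window_layers: int,
                           sliding_window: Optional[int]) -> bool:
    # Sliding window enabled, but only a prefix of the layers would use it.
    return (sliding_window is not None
            and max_window_layers < num_hidden_layers)

# Example: a 32-layer model whose config applies the window to only the
# first 28 layers would be rejected.
assert rejects_partial_window(32, 28, sliding_window=4096)
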
+ 0 - 2
aphrodite/modeling/models/starcoder2.py

@@ -73,7 +73,6 @@ class Starcoder2Attention(nn.Module):
         self.rope_theta = config.rope_theta
         self.max_position_embeddings = config.max_position_embeddings
         self.use_bias = config.use_bias
-        self.sliding_window = config.sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
@@ -100,7 +99,6 @@ class Starcoder2Attention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=self.sliding_window,
                               cache_config=cache_config,
                               quant_config=quant_config)
 

+ 0 - 4
aphrodite/modeling/models/xverse.py

@@ -87,7 +87,6 @@ class XverseAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
@@ -133,7 +132,6 @@ class XverseAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window,
                               cache_config=cache_config,
                               quant_config=quant_config)
 
@@ -166,7 +164,6 @@ class XverseDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = XverseAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -177,7 +174,6 @@ class XverseDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=getattr(config, "bias", False),
-            sliding_window=sliding_window,
             cache_config=cache_config,
         )
         self.mlp = XverseMLP(