
make fp8_e4m3 work on nvidia

AlpinDale, 7 months ago
Parent
Current commit
656459fd84
36 files changed, with 218 additions and 112 deletions
  1. + 23 - 3  aphrodite/attention/layer.py
  2. + 3 - 5  aphrodite/common/config.py
  3. + 2 - 0  aphrodite/common/utils.py
  4. + 3 - 4  aphrodite/engine/args_tools.py
  5. + 2 - 1  aphrodite/modeling/models/arctic.py
  6. + 4 - 2  aphrodite/modeling/models/baichuan.py
  7. + 2 - 1  aphrodite/modeling/models/bloom.py
  8. + 6 - 7  aphrodite/modeling/models/chatglm.py
  9. + 6 - 7  aphrodite/modeling/models/commandr.py
  10. + 6 - 7  aphrodite/modeling/models/dbrx.py
  11. + 2 - 1  aphrodite/modeling/models/deepseek.py
  12. + 6 - 3  aphrodite/modeling/models/falcon.py
  13. + 2 - 1  aphrodite/modeling/models/gemma.py
  14. + 2 - 1  aphrodite/modeling/models/gpt2.py
  15. + 2 - 1  aphrodite/modeling/models/gpt_bigcode.py
  16. + 2 - 1  aphrodite/modeling/models/gpt_j.py
  17. + 2 - 1  aphrodite/modeling/models/gpt_neox.py
  18. + 2 - 1  aphrodite/modeling/models/internlm2.py
  19. + 6 - 7  aphrodite/modeling/models/jais.py
  20. + 18 - 14  aphrodite/modeling/models/llama.py
  21. + 2 - 1  aphrodite/modeling/models/minicpm.py
  22. + 21 - 8  aphrodite/modeling/models/mixtral.py
  23. + 7 - 8  aphrodite/modeling/models/mixtral_quant.py
  24. + 2 - 1  aphrodite/modeling/models/mpt.py
  25. + 2 - 1  aphrodite/modeling/models/olmo.py
  26. + 2 - 1  aphrodite/modeling/models/opt.py
  27. + 2 - 1  aphrodite/modeling/models/orion.py
  28. + 2 - 1  aphrodite/modeling/models/phi.py
  29. + 2 - 1  aphrodite/modeling/models/qwen.py
  30. + 2 - 1  aphrodite/modeling/models/qwen2.py
  31. + 2 - 1  aphrodite/modeling/models/qwen2_moe.py
  32. + 2 - 1  aphrodite/modeling/models/stablelm.py
  33. + 7 - 8  aphrodite/modeling/models/starcoder2.py
  34. + 2 - 1  aphrodite/modeling/models/xverse.py
  35. + 48 - 4  aphrodite/quantization/fp8.py
  36. + 12 - 5  aphrodite/task_handler/model_runner.py

+ 23 - 3
aphrodite/attention/layer.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 from aphrodite.attention.backends.abstract import AttentionMetadata
 from aphrodite.attention.selector import get_attn_backend
 from aphrodite.common.config import CacheConfig
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class Attention(nn.Module):
@@ -28,6 +29,7 @@ class Attention(nn.Module):
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         if cache_config is not None:
@@ -38,6 +40,26 @@ class Attention(nn.Module):
             block_size = 16
         if num_kv_heads is None:
             num_kv_heads = num_heads
+
+        # The default kv_scale is set to 1.0. This is ignored
+        # when kv-cache is not fp8, and should be used with
+        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+        # expect the pre-quantized kv_scale to be loaded along
+        # with the model weights.
+        self.kv_cache_dtype = kv_cache_dtype
+        self._kv_scale = 1.0
+        quant_method = quant_config.get_quant_method(
+            self) if quant_config else None
+        if quant_method is not None:
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                 "fp8 checkpoints.")
+            # When FP8 quantization is enabled, we make a parameter
+            # "kv_scale" so that it can be loaded from FP8 checkpoint.
+            # The kv_scale will then be converted back
+            # to self._kv_scale in a native float32 value after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -55,10 +77,8 @@ class Attention(nn.Module):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        kv_scale: float = 1.0,
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 kv_scale)
+        return self.impl.forward(query, key, value, kv_cache, attn_metadata)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore

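Note on this change: the attention layer now owns its kv-cache scale instead of receiving `kv_scale` on every `forward` call. Below is a minimal, illustrative sketch of that ownership model, assuming a quant method whose `create_weights` registers the loadable parameter; the `Toy*` class names are hypothetical and not part of the diff.

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

class ToyKVCacheMethod:
    """Stand-in for an FP8 kv-cache quant method (illustrative only)."""

    def create_weights(self, layer: nn.Module) -> None:
        # Default scale of 1.0; overwritten if the checkpoint provides one.
        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

class ToyAttention(nn.Module):
    """Stand-in attention layer that carries its own kv-cache scale."""

    def __init__(self, kv_cache_dtype: str, quant_method=None):
        super().__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self._kv_scale = 1.0  # plain float used at runtime
        if quant_method is not None:
            quant_method.create_weights(self)  # adds the loadable kv_scale

    def forward(self, q, k, v):
        # The scale is read from self, not passed in as an argument.
        return q * self._kv_scale  # stand-in for the real attention kernel

attn = ToyAttention("fp8_e4m3", quant_method=ToyKVCacheMethod())
print(attn.kv_scale)  # Parameter(1.0); replaced during weight loading
```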
+ 3 - 5
aphrodite/common/config.py

@@ -443,14 +443,12 @@ class CacheConfig:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype == "fp8":
+        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
-                "But it may cause slight accuracy drop without scaling "
-                "factors. FP8_E5M2 (without scaling) is only supported on "
-                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
-                "is instead supported for common inference criteria.")
+                "Meanwhile, it may cause accuracy drop without a proper "
+                "scaling factor")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 

+ 2 - 0
aphrodite/common/utils.py

@@ -30,6 +30,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8": torch.uint8,
+    "fp8_e4m3": torch.uint8,
+    "fp8_e5m2": torch.uint8,
 }
 
 

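All three fp8 spellings map to `torch.uint8` because the kv-cache tensors only provide raw byte storage; interpreting the bytes as e4m3 or e5m2 is left to the attention kernels. The sketch below mirrors the mapping above (only the entries visible in this hunk, plus a hypothetical helper for the `"auto"` fallback):

```python
import torch

# Partial mirror of STR_DTYPE_TO_TORCH_DTYPE: fp8 variants are stored as
# raw uint8 bytes; the e4m3/e5m2 interpretation happens in the kernels.
STR_DTYPE_TO_TORCH_DTYPE = {
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

def kv_cache_storage_dtype(cache_dtype: str,
                           model_dtype: torch.dtype) -> torch.dtype:
    # Illustrative helper: "auto" falls back to the model dtype.
    if cache_dtype == "auto":
        return model_dtype
    return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

print(kv_cache_storage_dtype("fp8_e4m3", torch.float16))  # torch.uint8
```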
+ 3 - 4
aphrodite/engine/args_tools.py

@@ -207,12 +207,11 @@ class EngineArgs:
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
-            choices=['auto', 'fp8'],
+            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
-            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
-            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
-            'supported for common inference criteria. ')
+            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
         parser.add_argument(
             '--quantization-param-path',
             type=str,

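As a quick illustration of the widened flag, here is a standalone argparse stub with the same choices; it is not the real EngineArgs parser, just a sketch showing that the new spellings are accepted.

```python
import argparse

# Stand-in parser mirroring the updated --kv-cache-dtype choices.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--kv-cache-dtype",
    type=str,
    choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
    default="auto",
    help='Data type for kv cache storage. If "auto", uses the model dtype. '
         "CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2; "
         "ROCm supports fp8 (=fp8_e4m3).")

args = parser.parse_args(["--kv-cache-dtype", "fp8_e4m3"])
print(args.kv_cache_dtype)  # fp8_e4m3
```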
+ 2 - 1
aphrodite/modeling/models/arctic.py

@@ -265,7 +265,8 @@ class ArcticAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 4 - 2
aphrodite/modeling/models/baichuan.py

@@ -153,7 +153,8 @@ class BaiChuanAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scaling,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  quant_config=quant_config)
         else:
             self.rotary_emb = get_rope(
                 self.head_dim,
@@ -165,7 +166,8 @@ class BaiChuanAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
-                                  cache_config=cache_config)
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/bloom.py

@@ -110,7 +110,8 @@ class BloomAttention(nn.Module):
                               self.head_dim,
                               scaling,
                               alibi_slopes=alibi_slopes,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 6 - 7
aphrodite/modeling/models/chatglm.py

@@ -85,13 +85,12 @@ class GLMAttention(nn.Module):
             base=10000 * rope_ratio,
             is_neox_style=False,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 6 - 7
aphrodite/modeling/models/commandr.py

@@ -176,13 +176,12 @@ class CohereAttention(nn.Module):
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                  self.head_dim),

+ 6 - 7
aphrodite/modeling/models/dbrx.py

@@ -217,13 +217,12 @@ class DbrxAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/deepseek.py

@@ -231,7 +231,8 @@ class DeepseekAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 6 - 3
aphrodite/modeling/models/falcon.py

@@ -152,7 +152,8 @@ class FalconAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads)
+                                  num_kv_heads=self.num_kv_heads,
+                                  quant_config=quant_config)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
@@ -164,13 +165,15 @@ class FalconAttention(nn.Module):
                                   self.head_dim,
                                   self.inv_norm_factor,
                                   num_kv_heads=self.num_kv_heads,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  quant_config=quant_config)
         else:
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scale=self.inv_norm_factor,
                                   num_kv_heads=self.num_kv_heads,
-                                  cache_config=cache_config)
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/gemma.py

@@ -154,7 +154,8 @@ class GemmaAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/gpt2.py

@@ -74,7 +74,8 @@ class GPT2Attention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scale,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/gpt_bigcode.py

@@ -87,7 +87,8 @@ class GPTBigCodeAttention(nn.Module):
                               self.head_dim,
                               scale=self.scale,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/gpt_j.py

@@ -87,7 +87,8 @@ class GPTJAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/gpt_neox.py

@@ -88,7 +88,8 @@ class GPTNeoXAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/internlm2.py

@@ -116,7 +116,8 @@ class InternLM2Attention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 6 - 7
aphrodite/modeling/models/jais.py

@@ -104,13 +104,12 @@ class JAISAttention(nn.Module):
         head_end = (tp_rank + 1) * self.num_heads
         alibi_slopes = _get_alibi_slopes(total_num_heads)
         alibi_slopes = alibi_slopes[head_start:head_end]
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            scale=self.scale,
-            alibi_slopes=alibi_slopes,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              alibi_slopes=alibi_slopes,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 18 - 14
aphrodite/modeling/models/llama.py

@@ -46,7 +46,7 @@ from aphrodite.modeling.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.common.utils import is_hip
+from aphrodite.common.utils import is_hip, print_warning_once
 
 
 class LlamaMLP(nn.Module):
@@ -121,15 +121,6 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        # This will be overwritten by model initialization if we are using it.
-        # N.B. currently we only support per tensor scalar scaling factors
-        # & only applicable to ROCm (AMD GPU).
-        # The scaling factor convention we are assuming is
-        # quantized_value * scaling_factor ~= true_value
-        # which is consistent with the practice of setting
-        # scaling_factor = tensor_amax / FPtype_max
-        self.kv_scale = 1.0
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,
@@ -157,7 +148,8 @@ class LlamaAttention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               sliding_window=sliding_window,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,
@@ -169,8 +161,7 @@ class LlamaAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -424,6 +415,19 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            f"Found kv scale in the checkpoint (e.g. {name}), "
+                            "but not found the expected name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
+                            "not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -448,7 +452,7 @@ class LlamaForCausalLM(nn.Module):
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
             if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.kv_scale = scaling_factor
+                layer_self_attn.attn._kv_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

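The weight-loader change above boils down to a name rewrite: a checkpoint tensor such as `model.layers.0.self_attn.kv_scale` must land on the attention module's `attn.kv_scale` parameter, and is skipped with a warning when no such parameter exists. A standalone sketch of that remapping (the helper name is hypothetical; the diff inlines this logic in `load_weights`):

```python
from typing import Dict, Optional

def remap_kv_scale_name(name: str,
                        params_dict: Dict[str, object]) -> Optional[str]:
    """Return the parameter name to load a checkpoint kv_scale into,
    or None if the model has no matching parameter (skip with a warning)."""
    if not name.endswith("kv_scale"):
        return name
    remapped = name.replace(".kv_scale", ".attn.kv_scale")
    if remapped not in params_dict:
        print(f"WARNING: found kv scale {name} in the checkpoint but no "
              f"{remapped} in the model; kv-scale is not loaded.")
        return None
    return remapped

params = {"model.layers.0.self_attn.attn.kv_scale": object()}
print(remap_kv_scale_name("model.layers.0.self_attn.kv_scale", params))
# -> model.layers.0.self_attn.attn.kv_scale
```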
+ 2 - 1
aphrodite/modeling/models/minicpm.py

@@ -235,7 +235,8 @@ class MiniCPMAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 21 - 8
aphrodite/modeling/models/mixtral.py

@@ -306,14 +306,13 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,
@@ -579,6 +578,20 @@ class MixtralForCausalLM(nn.Module):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale")
+                        if remapped_kv_scale_name not in params_dict:
+                            print_warning_once(
+                                "Found kv scale in the checkpoint "
+                                f"(e.g. {name}), but not found the expected "
+                                f"name in the model "
+                                f"(e.g. {remapped_kv_scale_name}). "
+                                "kv-scale is not loaded.")
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)

+ 7 - 8
aphrodite/modeling/models/mixtral_quant.py

@@ -212,14 +212,13 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/mpt.py

@@ -109,7 +109,8 @@ class MPTAttention(nn.Module):
                               scaling,
                               alibi_slopes=alibi_slopes,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/olmo.py

@@ -95,7 +95,8 @@ class OlmoAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
         # Attention output projection.
         self.o_proj = RowParallelLinear(

+ 2 - 1
aphrodite/modeling/models/opt.py

@@ -90,7 +90,8 @@ class OPTAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/orion.py

@@ -120,7 +120,8 @@ class OrionAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/phi.py

@@ -109,7 +109,8 @@ class PhiAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/qwen.py

@@ -105,7 +105,8 @@ class QWenAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/qwen2.py

@@ -140,7 +140,8 @@ class Qwen2Attention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               sliding_window=self.sliding_window,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/qwen2_moe.py

@@ -240,7 +240,8 @@ class Qwen2MoeAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/stablelm.py

@@ -126,7 +126,8 @@ class StablelmAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_key_value_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 7 - 8
aphrodite/modeling/models/starcoder2.py

@@ -96,14 +96,13 @@ class Starcoder2Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 2 - 1
aphrodite/modeling/models/xverse.py

@@ -134,7 +134,8 @@ class XverseAttention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               sliding_window=sliding_window,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

+ 48 - 4
aphrodite/quantization/fp8.py

@@ -1,14 +1,16 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
 from contextlib import suppress
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+from loguru import logger
 from torch.nn import Module
 from torch.nn.parameter import Parameter
-from loguru import logger
 
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.quantization.base_config import (QuantizationConfig)
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
+from aphrodite.common.utils import print_warning_once
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -96,9 +98,13 @@ class Fp8Config(QuantizationConfig):
                    activation_scheme=activation_scheme)
 
     def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        from aphrodite.attention.layer import Attention  # Avoid circular import
+
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
+        if isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -288,6 +294,44 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])
 
 
+class Fp8KVCacheMethod(QuantizeMethodBase):
+    """Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """Create "weight" (aka kv_scale) for an attention layer. 
+        
+        Args:
+            layer: The layer that is using the QuantizeMethodBase factory.
+        """
+        # Initialize the KV cache scale to 1.0 as the default value.
+        # If the kv_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        if layer.kv_cache_dtype != "auto":
+            kv_scale = layer.kv_scale.to("cpu").tolist()
+            if not isinstance(kv_scale, float):
+                raise ValueError("Only support per-tensor scaling factor "
+                                 "for fp8 KV cache")
+            layer._kv_scale = kv_scale
+            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+                print_warning_once(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
+                    "cause accuracy issues. Please make sure kv-cache scaling "
+                    "factor is available in the fp8 checkpoint.")
+        del layer.kv_scale
+
+
 def all_close_1d(x: torch.Tensor) -> bool:
     assert len(x.shape) == 1
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

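Taken together, the kv-scale lifecycle is: `create_weights` registers a loadable `kv_scale` parameter, weight loading may overwrite it, and `process_weights_after_loading` folds it into a plain Python float (`layer._kv_scale`) and drops the parameter. A condensed, illustrative walk-through using a toy layer that only carries the attributes the quant method touches:

```python
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

class ToyAttnLayer(Module):
    """Toy layer with just the attributes the kv-cache method touches."""

    def __init__(self, kv_cache_dtype: str):
        super().__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self._kv_scale = 1.0

layer = ToyAttnLayer("fp8_e4m3")

# 1) create_weights: default scale of 1.0, loadable from the checkpoint.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

# 2) Weight loading may overwrite it (here: a pre-quantized scale).
layer.kv_scale.data.fill_(0.023)

# 3) process_weights_after_loading: fold into a float, drop the parameter.
if layer.kv_cache_dtype != "auto":
    kv_scale = layer.kv_scale.to("cpu").tolist()
    assert isinstance(kv_scale, float), "only per-tensor scales are supported"
    layer._kv_scale = kv_scale
del layer.kv_scale

print(layer._kv_scale)  # ~0.023
```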
+ 12 - 5
aphrodite/task_handler/model_runner.py

@@ -1,4 +1,5 @@
 import time
+import warnings
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -209,11 +210,21 @@ class ModelRunner:
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         if self.kv_cache_dtype == "fp8" and is_hip():
-            # Currently scaled KV cache is only enabled on ROCm
+            # Currently only ROCm accepts kv-cache scaling factors
+            # via quantization_param_path and this will be deprecated
+            # in the future.
             if self.model_config.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    warnings.warn(
+                        "Loading kv cache scaling factor from JSON is "
+                        "deprecated and will be removed. Please include "
+                        "kv cache scaling factors in the model checkpoint.",
+                        FutureWarning,
+                        stacklevel=2)
                     self.model.load_kv_cache_scales(
                         self.model_config.quantization_param_path)
+                    logger.info("Loaded KV cache scaling factors from "
+                                f"{self.model_config.quantization_param_path}")
                 else:
                     raise RuntimeError("Using FP8 KV cache and scaling factors"
                                        " provided but model "
@@ -224,10 +235,6 @@ class ModelRunner:
                     "Using FP8 KV cache but no scaling factors "
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")
-        elif self.model_config.quantization_param_path is not None:
-            logger.warning("KV cache scaling factors provided, "
-                           "but the KV cache data type is not FP8. "
-                           "KV cache scaling factors will not be used.")
 
     def save_sharded_state(
         self,
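
The JSON path for kv-cache scales (`--quantization-param-path`) now emits a FutureWarning on ROCm. A small, generic sketch (not Aphrodite-specific) of how downstream tests could surface that warning; the `load_scales_from_json` helper is hypothetical:

```python
import warnings

def load_scales_from_json(path: str) -> None:
    # Hypothetical loader that mirrors the deprecation notice in the diff.
    warnings.warn(
        "Loading kv cache scaling factor from JSON is deprecated and will "
        "be removed. Please include kv cache scaling factors in the model "
        "checkpoint.",
        FutureWarning,
        stacklevel=2)
    # ... JSON parsing would go here ...

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    load_scales_from_json("kv_cache_scales.json")

assert any(issubclass(w.category, FutureWarning) for w in caught)
```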