
chore: separate kv_scale into k_scale and v_scale

AlpinDale 7 months ago
commit 9d7beaa5b9

+ 11 - 7
aphrodite/_custom_ops.py

@@ -81,7 +81,8 @@ def paged_attention_v1(
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
     kv_cache_dtype: str,
-    kv_scale: float,
+    k_scale: float,
+    v_scale: float,
     tp_rank: int = 0,
     blocksparse_local_blocks: int = 0,
     blocksparse_vert_stride: int = 0,
@@ -91,8 +92,9 @@ def paged_attention_v1(
     torch.ops._C.paged_attention_v1(
         out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
         seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-        kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
-        blocksparse_block_size, blocksparse_head_sliding_step)
+        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
+        blocksparse_vert_stride, blocksparse_block_size,
+        blocksparse_head_sliding_step)
 
 
 def paged_attention_v2(
@@ -111,7 +113,8 @@ def paged_attention_v2(
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
     kv_cache_dtype: str,
-    kv_scale: float,
+    k_scale: float,
+    v_scale: float,
     tp_rank: int = 0,
     blocksparse_local_blocks: int = 0,
     blocksparse_vert_stride: int = 0,
@@ -121,7 +124,7 @@ def paged_attention_v2(
     torch.ops._C.paged_attention_v2(
         out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
         num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
-        alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
+        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
         blocksparse_local_blocks, blocksparse_vert_stride,
         blocksparse_block_size, blocksparse_head_sliding_step)
 
@@ -418,11 +421,12 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
-    kv_scale: float,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                              value_cache, slot_mapping,
-                                             kv_cache_dtype, kv_scale)
+                                             kv_cache_dtype, k_scale, v_scale)
 
 
 def reshape_and_cache_flash(

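For a caller in the Python layer, the visible effect of this change is one extra float at every call site: the single kv_scale becomes a k_scale/v_scale pair. Below is a minimal sketch of invoking the updated reshape_and_cache wrapper, assuming the compiled extension and a CUDA device are available; the tensor shapes and the paged-cache layouts are illustrative assumptions for the sketch, not part of the diff.

import torch
from aphrodite import _custom_ops as ops

# Illustrative shapes only: 16 tokens, 8 KV heads, head size 128,
# 32 cache blocks of 16 slots; the key cache packs fp16 in groups of 8.
key = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")
value = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")
key_cache = torch.empty(32, 8, 128 // 8, 16, 8,
                        dtype=torch.float16, device="cuda")
value_cache = torch.empty(32, 8, 128, 16,
                          dtype=torch.float16, device="cuda")
slot_mapping = torch.arange(16, dtype=torch.long, device="cuda")

# Before this commit the last argument was a single kv_scale; now the key
# and value caches are scaled independently.
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                      "auto", 1.0, 1.0)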
+ 6 - 3
aphrodite/_ipex_ops.py

@@ -53,7 +53,8 @@ class ipex_ops:
         max_context_len: int,
         alibi_slopes: Optional[torch.Tensor],
         kv_cache_dtype: str,
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
         tp_rank: int = 0,
         blocksparse_local_blocks: int = 0,
         blocksparse_vert_stride: int = 0,
@@ -93,7 +94,8 @@ class ipex_ops:
         max_context_len: int,
         alibi_slopes: Optional[torch.Tensor],
         kv_cache_dtype: str,
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
         tp_rank: int = 0,
         blocksparse_local_blocks: int = 0,
         blocksparse_vert_stride: int = 0,
@@ -221,7 +223,8 @@ class ipex_ops:
         value_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
         kv_cache_dtype: str,
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
     ) -> None:
         assert kv_cache_dtype == "auto"
         ipex.llm.modules.PagedAttention.reshape_and_cache(

+ 2 - 1
aphrodite/attention/backends/abstract.py

@@ -134,7 +134,8 @@ class AttentionImpl(ABC, Generic[T]):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         raise NotImplementedError

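Every backend changed below implements this widened signature. As a quick orientation, here is a minimal, hypothetical stub of what a backend's forward now has to accept; the class and its no-op body are illustrative only and not part of the commit.

from typing import Optional
import torch

class NoOpAttentionImpl:
    """Hypothetical backend stub showing only the new forward signature."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata,            # backend-specific metadata object
        k_scale: float = 1.0,     # previously a single kv_scale argument
        v_scale: float = 1.0,
        attn_type=None,           # AttentionType.DECODER in the real API
    ) -> torch.Tensor:
        # A real backend writes (key, value) into kv_cache using k_scale and
        # v_scale, then runs paged attention; this stub just returns the query.
        return query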
+ 6 - 3
aphrodite/attention/backends/blocksparse_attn.py

@@ -324,7 +324,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -363,7 +364,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
@@ -400,7 +402,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
                 tp_rank=self.tp_rank,
                 blocksparse_local_blocks=self.local_blocks,
                 blocksparse_vert_stride=self.vert_stride,

+ 4 - 2
aphrodite/attention/backends/flash_attn.py

@@ -258,7 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -278,7 +279,8 @@ class FlashAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "FlashAttentionImpl")
         # NOTE: FlashAttention does not support FP8 KV cache.
-        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
+        assert k_scale == 1.0 and v_scale == 1.0, (
+            "key/v_scale is not supported in FlashAttention.")
 
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.

+ 4 - 2
aphrodite/attention/backends/flashinfer.py

@@ -226,10 +226,12 @@ class FlashInferImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: FlashInferMetadata,
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        assert kv_scale == 1.0
+        assert k_scale == 1.0 and v_scale == 1.0, (
+            "key/v_scale is not supported in FlashInfer.")
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "

+ 9 - 5
aphrodite/attention/backends/ipex_attn.py

@@ -158,7 +158,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -171,7 +172,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert kv_scale == 1.0
+        assert k_scale == 1.0 and v_scale == 1.0
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
@@ -193,7 +194,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -274,7 +276,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    kv_scale,
+                    k_scale,
+                    v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -306,7 +309,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    kv_scale,
+                    k_scale,
+                    v_scale,
                 )
 
             # Reshape the output tensor.

+ 3 - 2
aphrodite/attention/backends/pallas.py

@@ -133,7 +133,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
         attn_metadata: PallasMetadata,
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -147,7 +148,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert kv_scale == 1.0
+        assert k_scale == 1.0 and v_scale == 1.0
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "

+ 6 - 3
aphrodite/attention/backends/rocm_flash_attn.py

@@ -298,7 +298,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -337,7 +338,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -457,7 +459,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         # Reshape the output tensor.

+ 7 - 5
aphrodite/attention/backends/torch_sdpa.py

@@ -146,7 +146,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -165,7 +166,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "TorchSDPABackendImpl")
-        assert kv_scale == 1.0
+        assert k_scale == 1.0 and v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -178,8 +179,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype,
-                                                kv_scale)
+                                                self.kv_cache_dtype, k_scale,
+                                                v_scale)
 
         if attn_metadata.is_prompt:
             assert attn_metadata.seq_lens is not None
@@ -242,7 +243,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         # Reshape the output tensor.

+ 5 - 3
aphrodite/attention/backends/xformers.py

@@ -426,7 +426,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         value: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor],
         attn_metadata: "XFormersMetadata",
-        kv_scale: float = 1.0,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -530,7 +531,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                                                     value_cache,
                                                     updated_slot_mapping,
                                                     self.kv_cache_dtype,
-                                                    kv_scale)
+                                                    k_scale, v_scale)
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -619,7 +620,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         # Reshape the output tensor.

+ 8 - 6
aphrodite/attention/layer.py

@@ -47,13 +47,14 @@ class Attention(nn.Module):
         if num_kv_heads is None:
             num_kv_heads = num_heads
 
-        # The default kv_scale is set to 1.0. This is ignored
+        # The default k/v_scale is set to 1.0. This is ignored
         # when kv-cache is not fp8, and should be used with
         # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized kv_scale to be loaded along
+        # expect the pre-quantized k/v_scale to be loaded along
         # with the model weights.
         self.kv_cache_dtype = kv_cache_dtype
-        self._kv_scale = 1.0
+        self._k_scale = 1.0
+        self._v_scale = 1.0
         quant_method = quant_config.get_quant_method(
             self) if quant_config else None
         if quant_method is not None:
@@ -65,8 +66,8 @@ class Attention(nn.Module):
                     raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                      "fp8 checkpoints.")
                 # When FP8 quantization is enabled, we make a parameter
-                # "kv_scale" so that it can be loaded from FP8 checkpoint.
-                # The kv_scale will then be converted back to self._kv_scale
+                # "k/v_scale" so that it can be loaded from FP8 checkpoint.
+                # The k/v_scale will then be converted back to self._k/v_scale
                 # in a native float32 value after weight loading.
                 self.quant_method = quant_method
                 self.quant_method.create_weights(self)
@@ -98,7 +99,8 @@ class Attention(nn.Module):
                                  value,
                                  kv_cache,
                                  attn_metadata,
-                                 self._kv_scale,
+                                 self._k_scale,
+                                 self._v_scale,
                                  attn_type=attn_type)
 
     def extra_repr(self) -> str:

+ 4 - 2
aphrodite/attention/ops/ipex_attn.py

@@ -45,7 +45,8 @@ class PagedAttention:
         value_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
         kv_cache_dtype: str,
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
         *args,
     ) -> None:
         ipex_modules.PagedAttention.reshape_and_cache(
@@ -64,7 +65,8 @@ class PagedAttention:
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
         *args,
     ) -> torch.Tensor:
         output = torch.empty_like(query)

+ 10 - 5
aphrodite/attention/ops/paged_attn.py

@@ -66,7 +66,8 @@ class PagedAttention:
         value_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
         kv_cache_dtype: str,
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
     ) -> None:
         ops.reshape_and_cache(
             key,
@@ -75,7 +76,8 @@ class PagedAttention:
             value_cache,
             slot_mapping.flatten(),
             kv_cache_dtype,
-            kv_scale,
+            k_scale,
+            v_scale,
         )
 
     @staticmethod
@@ -90,7 +92,8 @@ class PagedAttention:
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
         tp_rank: int = 0,
         blocksparse_local_blocks: int = 0,
         blocksparse_vert_stride: int = 0,
@@ -135,7 +138,8 @@ class PagedAttention:
                 max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
                 tp_rank,
                 blocksparse_local_blocks,
                 blocksparse_vert_stride,
@@ -172,7 +176,8 @@ class PagedAttention:
                 max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
                 tp_rank,
                 blocksparse_local_blocks,
                 blocksparse_vert_stride,

+ 9 - 0
aphrodite/modeling/layers/linear.py

@@ -194,6 +194,15 @@ class ReplicatedLinear(LinearBase):
         else:
             self.register_parameter("bias", None)
 
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # If the weight on disk does not have a shape, give it one
+        # (such scales for AutoFp8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param.size() == loaded_weight.size()
+        param.data.copy_(loaded_weight)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None

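The new ReplicatedLinear.weight_loader exists mostly so that 0-dim FP8 scale tensors from AutoFP8-style checkpoints can be copied into 1-element parameters. A self-contained illustration of the shape fix-up, using plain tensors rather than aphrodite classes:

import torch

param = torch.zeros(1)              # parameter registered with shape (1,)
loaded_weight = torch.tensor(0.5)   # scalar scale as stored on disk, shape ()

# Mirror of the loader: give a shapeless scale a shape before copying.
if len(loaded_weight.shape) == 0:
    loaded_weight = loaded_weight.reshape(1)

assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)     # param is now tensor([0.5000])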
+ 51 - 4
aphrodite/modeling/model_loader/weight_utils.py

@@ -19,6 +19,7 @@ from tqdm.auto import tqdm
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from aphrodite.common.config import LoadConfig, ModelConfig
+from aphrodite.common.utils import print_warning_once
 from aphrodite.quantization import QuantizationConfig, get_quantization_config
 from aphrodite.quantization.schema import QuantParamSchema
 
@@ -426,10 +427,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
 def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
-    # If the weight on disk does not have a shape, give it one
-    # (such scales for AutoFp8).
-    if len(loaded_weight.shape) == 0:
-        loaded_weight = loaded_weight.reshape(1)
+
     assert param.size() == loaded_weight.size()
     param.data.copy_(loaded_weight)
 
@@ -456,3 +454,52 @@ def initialize_dummy_weights(
                 param.data.copy_(tmp_param)
             else:
                 param.uniform_(low, high)
+
+
+def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
+    """Remap the name of FP8 k/v_scale parameters.
+    This function handles the remapping of FP8 k/v_scale parameter names.
+    It detects if the given name ends with a suffix and attempts to remap
+    it to the expected name format in the model. If the remapped name is not
+    found in the params_dict, a warning is printed and None is returned.
+    Args:
+        name (str): The original loaded checkpoint parameter name.
+        params_dict (dict): Dictionary containing the model's named parameters.
+    Returns:
+        str: The remapped parameter name if successful, or the original name
+             if no remapping is needed.
+        None: If the remapped name is not found in params_dict.
+    """
+    if name.endswith(".kv_scale"):
+        print_warning_once(
+            "DEPRECATED. Found kv_scale in the checkpoint. "
+            "This format is deprecated in favor of separate k_scale and "
+            "v_scale tensors and will be removed in a future release. "
+            "Functionally, we will remap kv_scale to k_scale and duplicate "
+            "k_scale to v_scale")
+        # NOTE: we remap the deprecated kv_scale to k_scale
+        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
+        if remapped_name not in params_dict:
+            print_warning_once(
+                f"Found kv_scale in the checkpoint (e.g. {name}), "
+                "but not found the expected name in the model "
+                f"(e.g. {remapped_name}). kv_scale is "
+                "not loaded.")
+            return None
+        return remapped_name
+
+    possible_scale_names = [".k_scale", ".v_scale"]
+    for scale_name in possible_scale_names:
+        if name.endswith(scale_name):
+            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            if remapped_name not in params_dict:
+                print_warning_once(
+                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
+                    "but not found the expected name in the model "
+                    f"(e.g. {remapped_name}). {scale_name} is "
+                    "not loaded.")
+                return None
+            return remapped_name
+
+    # If there were no matches, return the untouched param name
+    return name

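A short, runnable illustration of what maybe_remap_kv_scale_name returns for the name patterns it handles; the layer names and params_dict below are hypothetical and only need to contain the remapped keys.

from aphrodite.modeling.model_loader.weight_utils import (
    maybe_remap_kv_scale_name)

# Hypothetical parameter names, modeled on a llama-style attention module.
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": None,
    "model.layers.0.self_attn.attn.v_scale": None,
}

# Deprecated single kv_scale: remapped to k_scale (duplicated to v_scale later).
maybe_remap_kv_scale_name("model.layers.0.self_attn.kv_scale", params_dict)
# -> "model.layers.0.self_attn.attn.k_scale" (plus a deprecation warning)

# New-style per-tensor scales gain the ".attn" prefix.
maybe_remap_kv_scale_name("model.layers.0.self_attn.v_scale", params_dict)
# -> "model.layers.0.self_attn.attn.v_scale"

# Any other name passes through untouched.
maybe_remap_kv_scale_name("model.layers.0.self_attn.qkv_proj.weight",
                          params_dict)
# -> "model.layers.0.self_attn.qkv_proj.weight"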
+ 5 - 14
aphrodite/modeling/models/llama.py

@@ -30,7 +30,7 @@ from transformers import LlamaConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import is_hip, print_warning_once
+from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -45,7 +45,7 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader)
+    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
 from aphrodite.modeling.models.interfaces import SupportsLoRA
 from aphrodite.modeling.models.utils import (is_pp_missing_parameter,
                                              make_layers)
@@ -460,18 +460,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 # Remapping the name of FP8 kv-scale.
-                if name.endswith("kv_scale"):
-                    remapped_kv_scale_name = name.replace(
-                        ".kv_scale", ".attn.kv_scale")
-                    if remapped_kv_scale_name not in params_dict:
-                        print_warning_once(
-                            f"Found kv scale in the checkpoint (e.g. {name}), "
-                            "but not found the expected name in the model "
-                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
-                            "not loaded.")
-                        continue
-                    else:
-                        name = remapped_kv_scale_name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
 
                 if is_pp_missing_parameter(name, self):
                     continue

+ 5 - 15
aphrodite/modeling/models/mixtral.py

@@ -30,7 +30,6 @@ from transformers import MixtralConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -42,7 +41,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
 
@@ -414,19 +414,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     if name.endswith(".bias") and name not in params_dict:
                         continue
                     # Remapping the name of FP8 kv-scale.
-                    if name.endswith("kv_scale"):
-                        remapped_kv_scale_name = name.replace(
-                            ".kv_scale", ".attn.kv_scale")
-                        if remapped_kv_scale_name not in params_dict:
-                            print_warning_once(
-                                "Found kv scale in the checkpoint "
-                                f"(e.g. {name}), but not found the expected "
-                                f"name in the model "
-                                f"(e.g. {remapped_kv_scale_name}). "
-                                "kv-scale is not loaded.")
-                            continue
-                        else:
-                            name = remapped_kv_scale_name
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)

+ 5 - 14
aphrodite/modeling/models/qwen2.py

@@ -31,7 +31,6 @@ from transformers import Qwen2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -43,7 +42,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from aphrodite.modeling.models.interfaces import SupportsLoRA
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -377,18 +377,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 # Remapping the name of FP8 kv-scale.
-                if name.endswith("kv_scale"):
-                    remapped_kv_scale_name = name.replace(
-                        ".kv_scale", ".attn.kv_scale")
-                    if remapped_kv_scale_name not in params_dict:
-                        print_warning_once(
-                            f"Found kv scale in the checkpoint (e.g. {name}), "
-                            "but not found the expected name in the model "
-                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
-                            "not loaded.")
-                        continue
-                    else:
-                        name = remapped_kv_scale_name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

+ 38 - 14
aphrodite/quantization/fp8.py

@@ -406,31 +406,55 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module):
-        """Create "weight" (aka kv_scale) for an attention layer.
-
+        """Create "weight" (aka k_scale and v_scale) for an attention layer.
         Args:
             layer: The layer that is using the QuantizeMethodBase factory.
         """
-        # Initialize the KV cache scale to 1.0 as the default value.
-        # If the kv_scale appears in the checkpoint, it will be
+        # Initialize the KV cache scales to -1.0, which is an invalid value.
+        # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+        layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
 
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
+        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         if layer.kv_cache_dtype != "auto":
-            kv_scale = layer.kv_scale.to("cpu").tolist()
-            if not isinstance(kv_scale, float):
+            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+                # We prefer to use separate k_scale and v_scale if present
+                k_scale = layer.k_scale.to("cpu").tolist()
+                v_scale = layer.v_scale.to("cpu").tolist()
+            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+                # If no scales were loaded (both scales are invalid negative
+                # values), use the default value of 1.0
+                k_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+                v_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+            else:
+                # If we find a single kv_scale in the checkpoint, we remap
+                # kv_scale to k_scale during weight loading, and duplicate
+                # k_scale to v_scale here
+                assert layer.k_scale > 0.0
+                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+                k_scale = scale_to_duplicate.to("cpu").tolist()
+                v_scale = scale_to_duplicate.to("cpu").tolist()
+
+            if not isinstance(k_scale, float) or not isinstance(
+                    v_scale, float):
                 raise ValueError("Only support per-tensor scaling factor "
                                  "for fp8 KV cache")
-            layer._kv_scale = kv_scale
-            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+
+            # These are used in the final Attention.forward()
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
+                    and "e5m2" not in layer.kv_cache_dtype):
                 print_warning_once(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
-                    "cause accuracy issues. Please make sure kv-cache scaling "
-                    "factor is available in the fp8 checkpoint.")
-        del layer.kv_scale
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
+                    "may cause accuracy issues. Please make sure k/v_scale "
+                    "scaling factors are available in the fp8 checkpoint.")
+
+        del layer.k_scale
+        del layer.v_scale

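To make the three-way scale resolution above easier to follow, here is the branch logic restated as standalone Python; the function name is mine, while the -1.0 sentinel and the duplication of a remapped kv_scale follow the code in the diff.

def resolve_kv_scales(k_scale: float, v_scale: float) -> tuple:
    """Restatement of the branches in process_weights_after_loading."""
    if k_scale > 0.0 and v_scale > 0.0:
        # Separate k_scale and v_scale were loaded from the checkpoint.
        return k_scale, v_scale
    if k_scale < 0.0 and v_scale < 0.0:
        # Nothing was loaded; both sentinels are still -1.0, so default to 1.0.
        return 1.0, 1.0
    # A legacy kv_scale was remapped to k_scale during loading; duplicate it.
    scale = max(k_scale, v_scale)
    return scale, scale

assert resolve_kv_scales(0.5, 0.25) == (0.5, 0.25)   # separate scales present
assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)   # no scales in checkpoint
assert resolve_kv_scales(0.5, -1.0) == (0.5, 0.5)    # legacy kv_scale duplicated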
+ 30 - 29
kernels/attention/attention_kernels.cu

@@ -106,9 +106,9 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -175,7 +175,7 @@ __device__ void paged_attention_kernel(
   // Each thread in a thread group has a different part of the query.
   // For example, if the the thread group size is 4, then the first thread in
   // the group has 0, 4, 8, ... th vectors of the query, and the second thread
-  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because
   // q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
@@ -286,7 +286,7 @@ __device__ void paged_attention_kernel(
           Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
               k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
-              k_vec_quant, kv_scale);
+              k_vec_quant, k_scale);
         }
       }
 
@@ -416,7 +416,7 @@ __device__ void paged_attention_kernel(
               *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
           // Vector conversion from V_quant_vec to V_vec.
           v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
-                                                                    kv_scale);
+                                                                    v_scale);
         }
         if (block_idx == num_seq_blocks - 1) {
           // NOTE: When v_vec contains the tokens that are out of the
@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                          KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
       blocksparse_vert_stride, blocksparse_block_size,
       blocksparse_head_sliding_step);
 }
@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                          KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
       block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
+      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
       blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
       blocksparse_head_sliding_step);
 }
@@ -682,11 +682,11 @@ __global__ void paged_attention_v2_reduce_kernel(
           out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,    \
           scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,       \
           alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,         \
-          kv_scale, tp_rank, blocksparse_local_blocks,                         \
+          k_scale, v_scale, tp_rank, blocksparse_local_blocks,                 \
           blocksparse_vert_stride, blocksparse_block_size,                     \
           blocksparse_head_sliding_step);
 
-// TODO: Tune NUM_THREADS.
+// TODO(woosuk): Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
           aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
           int NUM_THREADS = 128>
@@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
-    const int tp_rank, const int blocksparse_local_blocks,
+    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
     const int blocksparse_vert_stride, const int blocksparse_block_size,
     const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
@@ -771,7 +771,7 @@ void paged_attention_v1_launcher(
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
                               IS_BLOCK_SPARSE>(                              \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank,                \
+      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,        \
       blocksparse_local_blocks, blocksparse_vert_stride,                     \
       blocksparse_block_size, blocksparse_head_sliding_step);
 
@@ -816,8 +816,8 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,      // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
@@ -834,7 +834,7 @@ void paged_attention_v1(
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
           seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
-          kv_block_stride, kv_head_stride, kv_scale, tp_rank,                  \
+          kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,          \
           blocksparse_local_blocks, blocksparse_vert_stride,                   \
           blocksparse_block_size, blocksparse_head_sliding_step);              \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
@@ -851,8 +851,8 @@ void paged_attention_v2_launcher(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
-    const int tp_rank, const int blocksparse_local_blocks,
+    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
     const int blocksparse_vert_stride, const int blocksparse_block_size,
     const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
@@ -917,7 +917,7 @@ void paged_attention_v2_launcher(
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
     case 192:
-      LAUNCH_PAGED_ATTENTION_V1(192);
+      LAUNCH_PAGED_ATTENTION_V2(192);
       break;
     case 256:
       LAUNCH_PAGED_ATTENTION_V2(256);
@@ -933,8 +933,9 @@ void paged_attention_v2_launcher(
                               IS_BLOCK_SPARSE>(                               \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
-      kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,   \
-      blocksparse_block_size, blocksparse_head_sliding_step);
+      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                    \
+      blocksparse_vert_stride, blocksparse_block_size,                        \
+      blocksparse_head_sliding_step);
 
 #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
   switch (is_block_sparse) {                                               \
@@ -981,8 +982,8 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,      // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);

+ 2 - 2
kernels/cache.h

@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype,
-                       const double kv_scale);
+                       const std::string& kv_cache_dtype, const double k_scale,
+                       const double v_scale);
 
 void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,

+ 18 - 18
kernels/cache_kernels.cu

@@ -38,7 +38,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
     TORCH_CHECK(false, "Invalid device combination");
   }
 
-  // NOTE(youkaichao): keep in mind that `block_mapping` should be
+  // NOTE: keep in mind that `block_mapping` should be
   // a cpu tensor, otherwise every `item` call will require a gpu-cpu
   // synchronization.
   TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
@@ -50,7 +50,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   const at::cuda::OptionalCUDAGuard device_guard(
       src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  // NOTE(woosuk): This can be slow if the number of blocks is large.
+  // NOTE: This can be slow if the number of blocks is large.
   const int64_t num_blocks = block_mapping.size(0);
   for (size_t i = 0; i < num_blocks; i++) {
     int64_t src_block_number = block_mapping[i][0].item<int64_t>();
@@ -95,7 +95,7 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
 
 }  // namespace aphrodite
 
-// NOTE: the key_caches and value_caches vectors are constant but
+// Note: the key_caches and value_caches vectors are constant but
 // not the Tensors they contain. The vectors need to be const refs
 // in order to satisfy pytorch's C++ operator registration code.
 void copy_blocks(std::vector<torch::Tensor> const& key_caches,
@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
                                          // block_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads,
-    const int head_size, const int block_size, const int x,
-    const float kv_scale) {
+    const int head_size, const int block_size, const int x, const float k_scale,
+    const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
       value_cache[tgt_value_idx] = tgt_value;
     } else {
       key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
       value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
     }
   }
 }
@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
           reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
           reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
           slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
-          num_heads, head_size, block_size, x, kv_scale);
+          num_heads, head_size, block_size, x, k_scale, v_scale);
 
 void reshape_and_cache(
     torch::Tensor& key,    // [num_tokens, num_heads, head_size]
@@ -258,7 +258,8 @@ void reshape_and_cache(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype, const double kv_scale) {
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -318,28 +319,27 @@ namespace aphrodite {
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                    Tout* __restrict__ dst_cache,
-                                   const float kv_scale,
+                                   const float scale,
                                    const int64_t block_stride) {
   const int64_t block_idx = blockIdx.x;
   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
     int64_t idx = block_idx * block_stride + i;
     dst_cache[idx] =
-        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
+        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
   }
 }
 
 }  // namespace aphrodite
 
-#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                      \
-  aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE>               \
-      <<<grid, block, 0, stream>>>(                                \
-          reinterpret_cast<Tin*>(src_cache.data_ptr()),            \
-          reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, \
-          block_stride);
+#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)           \
+  aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE>    \
+      <<<grid, block, 0, stream>>>(                     \
+          reinterpret_cast<Tin*>(src_cache.data_ptr()), \
+          reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
 
 // Only for testing.
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
-                 const double kv_scale, const std::string& kv_cache_dtype) {
+                 const double scale, const std::string& kv_cache_dtype) {
   torch::Device src_device = src_cache.device();
   torch::Device dst_device = dst_cache.device();
   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")

+ 6 - 6
kernels/cpu/attention.cpp

@@ -423,11 +423,11 @@ void paged_attention_v1(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
-  TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
   APHRODITE_DISPATCH_FLOATING_TYPES(
@@ -742,11 +742,11 @@ void paged_attention_v2(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
-  TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
   APHRODITE_DISPATCH_FLOATING_TYPES(

+ 3 - 2
kernels/cpu/cache.cpp

@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, double kv_scale) {
-  TORCH_CHECK(kv_scale == 1.0f);
+                       const std::string& kv_cache_dtype, double k_scale,
+                       double v_scale) {
+  TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);

+ 5 - 5
kernels/cpu/torch_bindings.cpp

@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes,"
-      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
-      "    int blocksparse_local_blocks,"
+      "    str kv_cache_dtype, float k_scale, float v_scale,"
+      "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
       "    int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes,"
-      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
-      "    int blocksparse_local_blocks,"
+      "    str kv_cache_dtype, float k_scale, float v_scale,"
+      "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
       "    int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "                  Tensor! key_cache, Tensor! value_cache,"
       "                  Tensor slot_mapping,"
       "                  str kv_cache_dtype,"
-      "                  float kv_scale) -> ()");
+      "                  float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
 }
 

+ 4 - 4
kernels/ops.h

@@ -8,8 +8,8 @@ void paged_attention_v1(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);
 
@@ -19,8 +19,8 @@ void paged_attention_v2(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
-    const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);
 

+ 5 - 5
kernels/torch_bindings.cpp

@@ -28,8 +28,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes,"
-      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
-      "    int blocksparse_local_blocks,"
+      "    str kv_cache_dtype, float k_scale, float v_scale,"
+      "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
       "    int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
@@ -42,8 +42,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes,"
-      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
-      "    int blocksparse_local_blocks,"
+      "    str kv_cache_dtype, float k_scale, float v_scale,"
+      "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
       "    int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
@@ -250,7 +250,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "                  Tensor! key_cache, Tensor! value_cache,"
       "                  Tensor slot_mapping,"
       "                  str kv_cache_dtype,"
-      "                  float kv_scale) -> ()");
+      "                  float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
 
   // Reshape the key and value tensors and cache them.