Kaynağa Gözat

fix: fp8 kv cache for qwen2 models

AlpinDale 7 ay önce
ebeveyn
işleme
025322ee5f
1 değiştirilmiş dosya ile 14 ekleme ve 0 silme
  1. 14 0
      aphrodite/modeling/models/qwen2.py

+ 14 - 0
aphrodite/modeling/models/qwen2.py

@@ -31,6 +31,7 @@ from transformers import Qwen2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -373,6 +374,19 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                     continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            f"Found kv scale in the checkpoint (e.g. {name}), "
+                            "but not found the expected name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
+                            "not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                                         default_weight_loader)