
fix: fp8 kv cache for qwen2 models

AlpinDale 7 months ago
parent commit 025322ee5f

+14 -0
aphrodite/modeling/models/qwen2.py

@@ -31,6 +31,7 @@ from transformers import Qwen2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -373,6 +374,19 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remap the FP8 kv-scale parameter name.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            f"Found kv scale in the checkpoint (e.g. {name}), "
+                            "but could not find the expected name in the "
+                            f"model (e.g. {remapped_kv_scale_name}); "
+                            "kv-scale is not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
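
Below is a minimal, self-contained sketch of what the added remapping does, for reading the change outside the full loader. The helper name remap_kv_scale, the example parameter strings, and the tiny params_dict are illustrative assumptions, not part of the diff; the real code runs inside Qwen2ForCausalLM.load_weights and uses print_warning_once rather than a bare print.

# Hypothetical stand-alone sketch of the kv-scale name remapping above.
def remap_kv_scale(name, params_dict):
    """Return the model-side name for an FP8 kv-scale entry, or None to skip it."""
    if not name.endswith("kv_scale"):
        return name  # Not a kv-scale entry; the name is used as-is.
    remapped = name.replace(".kv_scale", ".attn.kv_scale")
    if remapped not in params_dict:
        # The model has no matching parameter (e.g. FP8 kv-cache is not
        # enabled), so the checkpoint scale is skipped with a warning.
        print(f"warning: {name} found, but {remapped} is missing; kv-scale not loaded")
        return None
    return remapped

# Example: the checkpoint stores the scale on the self_attn module, while the
# model registers it on the inner Attention layer under ".attn.kv_scale".
params_dict = {"model.layers.0.self_attn.attn.kv_scale": object()}
print(remap_kv_scale("model.layers.0.self_attn.kv_scale", params_dict))
# -> model.layers.0.self_attn.attn.kv_scale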