feat: support FP8 for DeepSeekV2 MoE

AlpinDale 6 months ago
parent
commit
1efd0f89b7

+ 24 - 10
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -394,14 +394,15 @@ def fused_topk(
 
 
 # This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
+def grouped_topk(hidden_states: torch.Tensor,
+                 gating_output: torch.Tensor,
+                 topk: int,
+                 renormalize: bool,
+                 num_expert_group: int = 0,
+                 topk_group: int = 0):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = scores.view(num_token, num_expert_group,
@@ -556,6 +557,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -578,6 +582,10 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - use_grouped_topk (bool): If True, route with grouped_topk instead of
+        fused_topk. Note: the DeepSeek-V2 model uses grouped_topk.
+    - num_expert_group (Optional[int]): Expert group count for grouped_topk.
+    - topk_group (Optional[int]): Groups selected per token by grouped_topk.
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -591,8 +599,14 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
+                                              topk, renormalize,
+                                              num_expert_group, topk_group)
+    else:
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
     return fused_experts(hidden_states,
                          w1,
                          w2,
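Note: as a plain-PyTorch illustration of the routing scheme this flag enables (a sketch of the same logic as grouped_topk above, not the code path added in this diff; the helper name is made up for the example), group-limited top-k first picks topk_group expert groups per token and then topk experts within the surviving groups:

    import torch

    def grouped_topk_sketch(gating_output: torch.Tensor, topk: int,
                            num_expert_group: int, topk_group: int,
                            renormalize: bool = True):
        # Softmax over all experts, shape [num_tokens, num_experts].
        scores = torch.softmax(gating_output, dim=-1)
        num_token, num_expert = scores.shape
        # Score each group by its best expert, shape [num_tokens, num_groups].
        group_scores = scores.view(num_token, num_expert_group,
                                   -1).max(dim=-1).values
        # Keep only the topk_group best groups per token.
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                               sorted=False).indices
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = group_mask.unsqueeze(-1).expand(
            num_token, num_expert_group,
            num_expert // num_expert_group).reshape(num_token, -1)
        # Zero out experts from discarded groups, then take the final top-k.
        masked_scores = scores.masked_fill(score_mask == 0, 0.0)
        topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1,
                                            sorted=False)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1,
                                                           keepdim=True)
        return topk_weights, topk_ids

For DeepSeek-V2, num_expert_group and topk_group come from config.n_group and config.topk_group, as wired up in deepseek_v2.py below.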

+ 78 - 15
aphrodite/modeling/layers/fused_moe/layer.py

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -26,7 +26,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -60,7 +63,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -68,7 +74,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          router_logits,
                          top_k,
                          renormalize=renormalize,
-                         inplace=True)
+                         inplace=True,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class FusedMoE(torch.nn.Module):
@@ -98,6 +107,9 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
     ):
@@ -113,6 +125,11 @@ class FusedMoE(torch.nn.Module):
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -134,9 +151,8 @@ class FusedMoE(torch.nn.Module):
                       shard_id: int, expert_id: int):
         param_data = param.data
 
-        # FIXME: Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
@@ -144,14 +160,21 @@ class FusedMoE(torch.nn.Module):
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
-        # FIXME: Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
+        # Weight scales
         elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
+            # If we are in merged column case (gate_up_proj)
+            #   shard_id 0 == gate_proj / w1
+            #   shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            #   shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
@@ -182,10 +205,50 @@ class FusedMoE(torch.nn.Module):
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
-            renormalize=self.renormalize)
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
         return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [
+            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
+        ]
+
+        return [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in gate_up else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight"
+             if weight_name in gate_up else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.a13_scale"
+             if weight_name in gate_up else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ]
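The new classmethod centralizes the per-expert parameter bookkeeping that mixtral.py previously built inline. A rough usage sketch (Mixtral-style checkpoint names; the expert count is an arbitrary value for illustration):

    from aphrodite.modeling.layers.fused_moe import FusedMoE

    mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="w1",
        ckpt_down_proj_name="w2",
        ckpt_up_proj_name="w3",
        num_experts=2)

    # Entries are (param_name, weight_name, expert_id, shard_id), e.g.:
    #   ("experts.w13_scale", "experts.0.w1.weight_scale", 0, 0)
    #   ("experts.w2_scale",  "experts.0.w2.weight_scale", 0, 1)
    #   ("experts.w13_scale", "experts.0.w3.weight_scale", 0, 2)

shard_id follows the gate/down/up order (0 = gate or w1, 1 = down or w2, 2 = up or w3), which is what weight_loader above expects.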

+ 69 - 73
aphrodite/modeling/models/deepseek_v2.py

@@ -31,11 +31,10 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
+from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.fused_moe import fused_experts, grouped_topk
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
@@ -91,32 +90,34 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.n_routed_experts
-        self.top_k = config.num_experts_per_tok
         self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > self.n_routed_experts:
+        self.n_shared_experts = config.n_shared_experts
+        if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            DeepseekV2MLP(hidden_size=config.hidden_size,
-                          intermediate_size=config.moe_intermediate_size,
-                          hidden_act=config.hidden_act,
-                          quant_config=quant_config,
-                          reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.n_routed_experts}.")
+
+        if config.hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+                             "Only silu is supported for now.")
+
+        self.experts = FusedMoE(num_experts=config.n_routed_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config,
+                                use_grouped_topk=True,
+                                num_expert_group=config.n_group,
+                                topk_group=config.topk_group)
 
         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.n_routed_experts,
                                      bias=False,
                                      quant_config=None)
-
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -128,50 +129,21 @@ class DeepseekV2MoE(nn.Module):
                 reduce_results=False,
             )
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.config.n_shared_experts is not None:
+        shared_output = None
+        if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
-            num_expert_group=self.config.n_group,
-            topk_group=self.config.topk_group)
-        final_hidden_states = fused_experts(
-            hidden_states,
-            self.w1,
-            self.w2,
-            topk_weights,
-            topk_ids,
-            inplace=True) * self.routed_scaling_factor
-        if self.config.n_shared_experts is not None:
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+        if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -504,34 +476,58 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
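To make the new expert loading path concrete, the name rewrite performed by the loop above can be traced for a single routed-expert tensor (the checkpoint key is a hypothetical example of DeepSeek-V2 naming):

    # Hypothetical checkpoint key for expert 3's gate projection.
    name = "model.layers.1.mlp.experts.3.gate_proj.weight"

    # Matching expert_params_mapping entry for that tensor.
    param_name, weight_name, expert_id, shard_id = (
        "experts.w13_weight", "experts.3.gate_proj.weight", 3, 0)

    assert weight_name in name
    name = name.replace(weight_name, param_name)
    # name is now "model.layers.1.mlp.experts.w13_weight", the fused
    # FusedMoE parameter; its weight_loader is called with expert_id=3 and
    # shard_id=0 (gate / w1). up_proj maps to shard_id 2 (w3) and down_proj
    # to shard_id 1 (w2), matching FusedMoE.weight_loader in layer.py.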

+ 7 - 25
aphrodite/modeling/models/mixtral.py

@@ -371,31 +371,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in ["w1", "w3"] else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.w13_weight"
-             if weight_name in ["w1", "w3"] else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the activation scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.a13_scale"
-             if weight_name in ["w1", "w3"] else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:

+ 22 - 11
aphrodite/modeling/models/qwen2_moe.py

@@ -32,6 +32,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -405,15 +406,13 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
-             else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_experts) for shard_id,
-            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -460,8 +459,20 @@ class Qwen2MoeForCausalLM(nn.Module):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if name not in params_dict:
-                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale")
+                        if remapped_kv_scale_name not in params_dict:
+                            print_warning_once(
+                                "Found kv scale in the checkpoint "
+                                f"(e.g. {name}), but not found the expected "
+                                f"name in the model "
+                                f"(e.g. {remapped_kv_scale_name}). "
+                                "kv-scale is not loaded.")
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
 
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",

+ 8 - 2
aphrodite/quantization/fp8.py

@@ -376,7 +376,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -389,7 +392,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                          w1_scale=layer.w13_scale,
                          w2_scale=layer.w2_scale,
                          a1_scale=layer.a13_scale,
-                         a2_scale=layer.a2_scale)
+                         a2_scale=layer.a2_scale,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class Fp8KVCacheMethod(QuantizeMethodBase):