Forráskód Böngészése

feat: support FP8 for DeepSeekV2 MoE

AlpinDale 6 hónapja
szülő
commit
1efd0f89b7

+ 24 - 10
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -394,14 +394,15 @@ def fused_topk(
 
 
 # This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
+def grouped_topk(hidden_states: torch.Tensor,
+                 gating_output: torch.Tensor,
+                 topk: int,
+                 renormalize: bool,
+                 num_expert_group: int = 0,
+                 topk_group: int = 0):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = scores.view(num_token, num_expert_group,
@@ -556,6 +557,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -578,6 +582,10 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+        note: Deepseekv2 model uses grouped_topk
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -591,8 +599,14 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
+                                              topk, renormalize,
+                                              num_expert_group, topk_group)
+    else:
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
     return fused_experts(hidden_states,
                          w1,
                          w2,

+ 78 - 15
aphrodite/modeling/layers/fused_moe/layer.py

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -26,7 +26,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -60,7 +63,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -68,7 +74,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
                          router_logits,
                          top_k,
                          renormalize=renormalize,
-                         inplace=True)
+                         inplace=True,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class FusedMoE(torch.nn.Module):
@@ -98,6 +107,9 @@ class FusedMoE(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
     ):
@@ -113,6 +125,11 @@ class FusedMoE(torch.nn.Module):
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -134,9 +151,8 @@ class FusedMoE(torch.nn.Module):
                       shard_id: int, expert_id: int):
         param_data = param.data
 
-        # FIXME: Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
@@ -144,14 +160,21 @@ class FusedMoE(torch.nn.Module):
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
-        # FIXME: Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
+        # Weight scales
         elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
+            # If we are in merged column case (gate_up_proj)
+            #   shard_id 0 == gate_proj / w1
+            #   shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj)
+            #   shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
@@ -182,10 +205,50 @@ class FusedMoE(torch.nn.Module):
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
-            renormalize=self.renormalize)
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
         return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [
+            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
+        ]
+
+        return [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in gate_up else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight"
+             if weight_name in gate_up else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.a13_scale"
+             if weight_name in gate_up else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ]

+ 69 - 73
aphrodite/modeling/models/deepseek_v2.py

@@ -31,11 +31,10 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
+from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.fused_moe import fused_experts, grouped_topk
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
@@ -91,32 +90,34 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.n_routed_experts
-        self.top_k = config.num_experts_per_tok
         self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > self.n_routed_experts:
+        self.n_shared_experts = config.n_shared_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            DeepseekV2MLP(hidden_size=config.hidden_size,
-                          intermediate_size=config.moe_intermediate_size,
-                          hidden_act=config.hidden_act,
-                          quant_config=quant_config,
-                          reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.n_routed_experts}.")
+
+        if config.hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+                             "Only silu is supported for now.")
+
+        self.experts = FusedMoE(num_experts=config.n_routed_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config,
+                                use_grouped_topk=True,
+                                num_expert_group=config.n_group,
+                                topk_group=config.topk_group)
 
         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.n_routed_experts,
                                      bias=False,
                                      quant_config=None)
-
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -128,50 +129,21 @@ class DeepseekV2MoE(nn.Module):
                 reduce_results=False,
             )
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.config.n_shared_experts is not None:
+        if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
-            num_expert_group=self.config.n_group,
-            topk_group=self.config.topk_group)
-        final_hidden_states = fused_experts(
-            hidden_states,
-            self.w1,
-            self.w2,
-            topk_weights,
-            topk_ids,
-            inplace=True) * self.routed_scaling_factor
-        if self.config.n_shared_experts is not None:
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+        if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -504,34 +476,58 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

+ 7 - 25
aphrodite/modeling/models/mixtral.py

@@ -371,31 +371,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in ["w1", "w3"] else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.w13_weight"
-             if weight_name in ["w1", "w3"] else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the activation scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.a13_scale"
-             if weight_name in ["w1", "w3"] else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:

+ 22 - 11
aphrodite/modeling/models/qwen2_moe.py

@@ -32,6 +32,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -405,15 +406,13 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
-             else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_experts) for shard_id,
-            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -460,8 +459,20 @@ class Qwen2MoeForCausalLM(nn.Module):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if name not in params_dict:
-                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale")
+                        if remapped_kv_scale_name not in params_dict:
+                            print_warning_once(
+                                "Found kv scale in the checkpoint "
+                                f"(e.g. {name}), but not found the expected "
+                                f"name in the model "
+                                f"(e.g. {remapped_kv_scale_name}). "
+                                "kv-scale is not loaded.")
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
 
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",

+ 8 - 2
aphrodite/quantization/fp8.py

@@ -376,7 +376,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -389,7 +392,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                          w1_scale=layer.w13_scale,
                          w2_scale=layer.w2_scale,
                          a1_scale=layer.a13_scale,
-                         a2_scale=layer.a2_scale)
+                         a2_scale=layer.a2_scale,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class Fp8KVCacheMethod(QuantizeMethodBase):