فهرست منبع

refactor: isolate FP8 from mixtral

AlpinDale 6 ماه پیش
والد
کامیت
cf472315cc

+ 4 - 0
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -1,5 +1,7 @@
 from aphrodite.modeling.layers.fused_moe.fused_moe import (
     fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
+from aphrodite.modeling.layers.fused_moe.layer import (FusedMoE,
+                                                       FusedMoEMethodBase)
 
 __all__ = [
     "fused_moe",
@@ -7,4 +9,6 @@ __all__ = [
     "fused_experts",
     "get_config_file_name",
     "grouped_topk",
+    "FusedMoE",
+    "FusedMoEMethodBase",
 ]

+ 191 - 0
aphrodite/modeling/layers/fused_moe/layer.py

@@ -0,0 +1,191 @@
+from abc import abstractmethod
+from typing import Optional
+
+import torch
+
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+    """MoE method without quantization."""
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+
+        return fused_moe(x,
+                         layer.w13_weight,
+                         layer.w2_weight,
+                         router_logits,
+                         top_k,
+                         renormalize=renormalize,
+                         inplace=True)
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+    This layer contains both MergedColumnParallel weights (gate_up_proj / 
+    w13) and RowParallelLinear weights (down_proj/ w2).
+    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all all_reduce on the output of the layer
+        renomalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+    ):
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        self.tp_size = (tp_size if tp_size is not None else
+                        get_tensor_model_parallel_world_size())
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize
+
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod())
+        else:
+            self.quant_method = quant_config.get_quant_method(self)
+        assert self.quant_method is not None
+
+        self.quant_method.create_weights(
+            layer=self,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=self.intermediate_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, weight_name: str,
+                      shard_id: int, expert_id: int):
+        param_data = param.data
+
+        # FIXME: Overfit to Mixtral.
+        # Follow up PR to enable fp8 for other MoE models.
+        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
+            param_data[expert_id] = loaded_weight
+        # FIXME: Overfit to Mixtral.
+        # Follow up PR to enable fp8 for other MoE models.
+        elif "weight_scale" in weight_name:
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            assert "w1" in weight_name or "w3" in weight_name
+            shard_id = 0 if "w1" in weight_name else 1
+            param_data[expert_id][shard_id] = loaded_weight
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = self.intermediate_size_per_partition
+            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+
+            # w1, gate_proj case: Load into first shard of w13.
+            if shard_id == 0:
+                param_data[expert_id,
+                           0:shard_size, :] = loaded_weight[shard, :]
+            # w3, up_proj case: Load into second shard of w13.
+            elif shard_id == 2:
+                param_data[expert_id, shard_size:2 *
+                           shard_size, :] = loaded_weight[shard, :]
+            # w2, down_proj case: Load into only shard of w2.
+            elif shard_id == 1:
+                param_data[expert_id, :, :] = loaded_weight[:, shard]
+            else:
+                raise ValueError(
+                    f"Shard id must be in [0,1,2] but got {shard_id}")
+
+    def forward(self, hidden_states: torch.Tensor,
+                router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize)
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states

+ 55 - 243
aphrodite/modeling/models/mixtral.py

@@ -27,15 +27,12 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 
-from aphrodite import _custom_ops as ops
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import print_warning_once
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
-from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
@@ -46,12 +43,10 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.interfaces import SupportsLoRA
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.quantization.fp8 import (Fp8Config, per_tensor_dequantize,
-                                        per_tensor_quantize)
+
+from .interfaces import SupportsLoRA
 
 
 class MixtralMoE(nn.Module):
@@ -63,240 +58,55 @@ class MixtralMoE(nn.Module):
     across ranks.
     """
 
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 tp_size: Optional[int] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
+        self.gate = ReplicatedLinear(hidden_size,
+                                     num_experts,
                                      bias=False,
-                                     params_dtype=self.params_dtype,
+                                     params_dtype=params_dtype,
                                      quant_config=None)
 
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                   2 * self.intermediate_size,
-                                                   self.hidden_size,
-                                                   dtype=params_dtype),
-                                       requires_grad=False)
-        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                  self.hidden_size,
-                                                  self.intermediate_size,
-                                                  dtype=params_dtype),
-                                      requires_grad=False)
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
-
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They will be combined to a single scale after weight loading.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     2,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            #   process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.ones(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-
-        # Loading scales
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "act_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        else:
-            # If checkpoint is fp8 + static, cleanup act_scales.
-            #   Since state_dict has an input_scale per expert but our kernels
-            #   are passed one input_scale shared across all experts.
-            if self.quant_config.activation_scheme == "static":
-                if self.a13_scale is None or self.a2_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None.")
-
-                if (not all_close_1d(self.a13_scale)
-                        or not all_close_1d(self.a2_scale)):
-                    print_warning_once(
-                        "Found act_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer. ")
-
-                self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                             requires_grad=False)
-
-            assert self.w13_scale is not None
-            shard_size = self.intermediate_size
-            max_w13_scales = self.w13_scale.max(dim=1).values
-            for expert_id in range(self.num_total_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        self.w13_weight[expert_id][start:start +
-                                                   shard_size, :],
-                        self.w13_scale[expert_id][shard_id])
-                    self.w13_weight[expert_id][
-                        start:start + shard_size, :] = per_tensor_quantize(
-                            dq_weight, max_w13_scales[expert_id])
-                    start += shard_size
-
-            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
+        self.experts = FusedMoE(num_experts=num_experts,
+                                top_k=top_k,
+                                hidden_size=hidden_size,
+                                intermediate_size=intermediate_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=True,
+                                quant_config=quant_config,
+                                tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
+        final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -501,8 +311,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
         self.lora_config = lora_config
+
         self.model = MixtralModel(config,
                                   cache_config,
                                   quant_config,
@@ -559,25 +371,28 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
 
         expert_params_mapping = [
             # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in ["w1", "w3"] else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            ("experts.w13_weight"
+             if weight_name in ["w1", "w3"] else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            ("experts.a13_scale"
+             if weight_name in ["w1", "w3"] else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ]
 
         params_dict = dict(self.named_parameters())
@@ -597,7 +412,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -606,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:
@@ -630,8 +447,3 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
-
-
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

+ 69 - 67
aphrodite/modeling/models/qwen2_moe.py

@@ -32,11 +32,10 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
+from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -92,28 +91,23 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.n_routed_experts:
+
+        if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            Qwen2MoeMLP(hidden_size=config.hidden_size,
-                        intermediate_size=config.moe_intermediate_size,
-                        hidden_act=config.hidden_act,
-                        quant_config=quant_config,
-                        reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.num_experts}.")
+
+        self.experts = FusedMoE(num_experts=config.num_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config)
 
         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.num_experts,
                                      bias=False,
                                      quant_config=None)
         if config.shared_expert_intermediate_size > 0:
@@ -130,25 +124,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                                   1,
                                                   bias=False)
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -161,18 +136,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=self.config.norm_topk_prob,
-                                        inplace=True)
-
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -283,7 +253,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
         )
-        if (layer_idx not in config.mlp_only_layers) and (
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
+                           config.mlp_only_layers)
+        if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
             (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
@@ -426,38 +401,65 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
+        expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
+             else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(self.config.num_experts) for shard_id,
+            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
                 if name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
-                if name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

+ 189 - 1
aphrodite/quantization/fp8.py

@@ -8,6 +8,8 @@ from torch.nn.parameter import Parameter
 from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import (get_device_capability_stateless,
                                     print_warning_once)
+from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                 fused_moe)
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
@@ -70,7 +72,9 @@ class Fp8Config(QuantizationConfig):
 
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
-        if isinstance(layer, Attention):
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
         return None
 
@@ -267,6 +271,185 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])
 
 
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
+                       intermediate_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+
+        layer.process_after_load = True
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                  2,
+                                                  dtype=torch.float32),
+                                       requires_grad=False)
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 dtype=torch.float32),
+                                      requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        #   process_weights_after_loading()
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_scale, extra_weight_attrs)
+            set_weight_attrs(w2_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8.")
+
+            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                      dtype=torch.float32),
+                                           requires_grad=False)
+            layer.register_parameter("a13_scale", a13_scale)
+            set_weight_attrs(a13_scale, extra_weight_attrs)
+
+            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            layer.register_parameter("a2_scale", a2_scale)
+            set_weight_attrs(a2_scale, extra_weight_attrs)
+        else:
+            layer.a13_scale = None
+            layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if (not hasattr(layer, "process_after_load")
+                or not layer.process_after_load):
+            return
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(layer.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(layer.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_scale = torch.nn.Parameter(torch.ones(
+                layer.num_experts,
+                dtype=torch.float32,
+                device=w13_weight.device),
+                                                 requires_grad=False)
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], layer.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w2_weight.data[expert, :, :])
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require single activation scale and single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.a13_scale is None or layer.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None.")
+                if (not all_close_1d(layer.a13_scale)
+                        or not all_close_1d(layer.a2_scale)):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. ")
+                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
+                                                     requires_grad=False)
+                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
+                                                    requires_grad=False)
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :] = per_tensor_quantize(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+
+            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
+                                                 requires_grad=False)
+            return
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+
+        return fused_moe(x,
+                         layer.w13_weight,
+                         layer.w2_weight,
+                         router_logits,
+                         top_k,
+                         renormalize=renormalize,
+                         inplace=True,
+                         use_fp8=True,
+                         w1_scale=layer.w13_scale,
+                         w2_scale=layer.w2_scale,
+                         a1_scale=layer.a13_scale,
+                         a2_scale=layer.a2_scale)
+
+
 class Fp8KVCacheMethod(QuantizeMethodBase):
     """Supports loading kv-cache scaling factors from FP8 checkpoints.
     """
@@ -318,3 +501,8 @@ def per_tensor_dequantize(
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))