refactor: isolate FP8 from mixtral

AlpinDale 6 months ago
commit cf472315cc

+ 4 - 0
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -1,5 +1,7 @@
 from aphrodite.modeling.layers.fused_moe.fused_moe import (
     fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
+from aphrodite.modeling.layers.fused_moe.layer import (FusedMoE,
+                                                       FusedMoEMethodBase)
 
 
 __all__ = [
     "fused_moe",
@@ -7,4 +9,6 @@ __all__ = [
     "fused_experts",
     "get_config_file_name",
     "grouped_topk",
+    "FusedMoE",
+    "FusedMoEMethodBase",
 ]

+ 191 - 0
aphrodite/modeling/layers/fused_moe/layer.py

@@ -0,0 +1,191 @@
+from abc import abstractmethod
+from typing import Optional
+
+import torch
+
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+    """MoE method without quantization."""
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+
+        return fused_moe(x,
+                         layer.w13_weight,
+                         layer.w2_weight,
+                         router_logits,
+                         top_k,
+                         renormalize=renormalize,
+                         inplace=True)
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+    This layer contains both MergedColumnParallelLinear weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj / w2).
+    Note: Mixtral uses w1 (gate_proj), w3 (up_proj), and w2 (down_proj). We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all_reduce on the output of the layer
+        renormalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization config.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = False,
+        renormalize: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+    ):
+        super().__init__()
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+
+        self.tp_size = (tp_size if tp_size is not None else
+                        get_tensor_model_parallel_world_size())
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.reduce_results = reduce_results
+        self.renormalize = renormalize
+
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod())
+        else:
+            self.quant_method = quant_config.get_quant_method(self)
+        assert self.quant_method is not None
+
+        self.quant_method.create_weights(
+            layer=self,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=self.intermediate_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, weight_name: str,
+                      shard_id: int, expert_id: int):
+        param_data = param.data
+
+        # FIXME: Overfit to Mixtral.
+        # Follow up PR to enable fp8 for other MoE models.
+        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
+            param_data[expert_id] = loaded_weight
+        # FIXME: Overfit to Mixtral.
+        # Follow up PR to enable fp8 for other MoE models.
+        elif "weight_scale" in weight_name:
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            assert "w1" in weight_name or "w3" in weight_name
+            shard_id = 0 if "w1" in weight_name else 1
+            param_data[expert_id][shard_id] = loaded_weight
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = self.intermediate_size_per_partition
+            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+
+            # w1, gate_proj case: Load into first shard of w13.
+            if shard_id == 0:
+                param_data[expert_id,
+                           0:shard_size, :] = loaded_weight[shard, :]
+            # w3, up_proj case: Load into second shard of w13.
+            elif shard_id == 2:
+                param_data[expert_id, shard_size:2 *
+                           shard_size, :] = loaded_weight[shard, :]
+            # w2, down_proj case: Load into only shard of w2.
+            elif shard_id == 1:
+                param_data[expert_id, :, :] = loaded_weight[:, shard]
+            else:
+                raise ValueError(
+                    f"Shard id must be in [0,1,2] but got {shard_id}")
+
+    def forward(self, hidden_states: torch.Tensor,
+                router_logits: torch.Tensor):
+        assert self.quant_method is not None
+
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply(
+            self,
+            x=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            renormalize=self.renormalize)
+
+        if self.reduce_results and self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states
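
For context, a minimal wiring sketch for the new layer (it mirrors the Mixtral change below). The ToyMoEBlock name and the plain torch.nn.Linear gate are illustrative only, and tp_size=1 is passed so the sketch does not depend on distributed initialization.

import torch
from aphrodite.modeling.layers.fused_moe import FusedMoE

class ToyMoEBlock(torch.nn.Module):
    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int):
        super().__init__()
        # The router stays unquantized; FusedMoE owns w13_weight / w2_weight and
        # dispatches to its FusedMoEMethodBase (here the unquantized one) in forward().
        self.gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=None,  # -> UnquantizedFusedMoEMethod
                                tp_size=1)          # avoids needing a TP process group

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(hidden_states)
        return self.experts(hidden_states, router_logits)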

+ 55 - 243
aphrodite/modeling/models/mixtral.py

@@ -27,15 +27,12 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 
-from aphrodite import _custom_ops as ops
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import print_warning_once
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
-from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
@@ -46,12 +43,10 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.interfaces import SupportsLoRA
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.quantization.fp8 import (Fp8Config, per_tensor_dequantize,
-                                        per_tensor_quantize)
+
+from .interfaces import SupportsLoRA
 
 
 class MixtralMoE(nn.Module):
@@ -63,240 +58,55 @@ class MixtralMoE(nn.Module):
     across ranks.
     """
 
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 tp_size: Optional[int] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
+        self.gate = ReplicatedLinear(hidden_size,
+                                     num_experts,
                                      bias=False,
-                                     params_dtype=self.params_dtype,
+                                     params_dtype=params_dtype,
                                      quant_config=None)
 
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                   2 * self.intermediate_size,
-                                                   self.hidden_size,
-                                                   dtype=params_dtype),
-                                       requires_grad=False)
-        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                  self.hidden_size,
-                                                  self.intermediate_size,
-                                                  dtype=params_dtype),
-                                      requires_grad=False)
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
-
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They will be combined to a single scale after weight loading.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     2,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            #   process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.ones(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-
-        # Loading scales
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "act_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        else:
-            # If checkpoint is fp8 + static, cleanup act_scales.
-            #   Since state_dict has an input_scale per expert but our kernels
-            #   are passed one input_scale shared across all experts.
-            if self.quant_config.activation_scheme == "static":
-                if self.a13_scale is None or self.a2_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None.")
-
-                if (not all_close_1d(self.a13_scale)
-                        or not all_close_1d(self.a2_scale)):
-                    print_warning_once(
-                        "Found act_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer. ")
-
-                self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                             requires_grad=False)
-
-            assert self.w13_scale is not None
-            shard_size = self.intermediate_size
-            max_w13_scales = self.w13_scale.max(dim=1).values
-            for expert_id in range(self.num_total_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        self.w13_weight[expert_id][start:start +
-                                                   shard_size, :],
-                        self.w13_scale[expert_id][shard_id])
-                    self.w13_weight[expert_id][
-                        start:start + shard_size, :] = per_tensor_quantize(
-                            dq_weight, max_w13_scales[expert_id])
-                    start += shard_size
-
-            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
+        self.experts = FusedMoE(num_experts=num_experts,
+                                top_k=top_k,
+                                hidden_size=hidden_size,
+                                intermediate_size=intermediate_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=True,
+                                quant_config=quant_config,
+                                tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
+        final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -501,8 +311,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
         self.lora_config = lora_config
+
         self.model = MixtralModel(config,
                                   cache_config,
                                   quant_config,
@@ -559,25 +371,28 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
 
         expert_params_mapping = [
             # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in ["w1", "w3"] else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            ("experts.w13_weight"
+             if weight_name in ["w1", "w3"] else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            ("experts.a13_scale"
+             if weight_name in ["w1", "w3"] else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ]
 
         params_dict = dict(self.named_parameters())
@@ -597,7 +412,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -606,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:
@@ -630,8 +447,3 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
-
-
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
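
To make the new 4-tuple mapping concrete, this is the shape of the expert-weight entries it generates (the expert count of 8 is an illustrative value; the tuple layout follows the comprehension above, which in the model reads config.num_local_experts):

num_local_experts = 8  # illustrative value
expert_params_mapping = [
    ("experts.w13_weight" if weight_name in ["w1", "w3"] else "experts.w2_weight",
     f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
    for expert_id in range(num_local_experts)
    for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
]
# First expert's entries, as consumed by FusedMoE.weight_loader:
#   ("experts.w13_weight", "experts.0.w1.weight", 0, 0)  # gate_proj -> first half of w13
#   ("experts.w2_weight",  "experts.0.w2.weight", 0, 1)  # down_proj -> w2
#   ("experts.w13_weight", "experts.0.w3.weight", 0, 2)  # up_proj   -> second half of w13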

+ 69 - 67
aphrodite/modeling/models/qwen2_moe.py

@@ -32,11 +32,10 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
+from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -92,28 +91,23 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.n_routed_experts:
+
+        if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            Qwen2MoeMLP(hidden_size=config.hidden_size,
-                        intermediate_size=config.moe_intermediate_size,
-                        hidden_act=config.hidden_act,
-                        quant_config=quant_config,
-                        reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.num_experts}.")
+
+        self.experts = FusedMoE(num_experts=config.num_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config)
 
         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.num_experts,
                                      bias=False,
                                      quant_config=None)
         if config.shared_expert_intermediate_size > 0:
@@ -130,25 +124,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                                   1,
                                                   bias=False)
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -161,18 +136,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=self.config.norm_topk_prob,
-                                        inplace=True)
-
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -283,7 +253,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
         )
-        if (layer_idx not in config.mlp_only_layers) and (
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
+                           config.mlp_only_layers)
+        if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
             (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
@@ -426,38 +401,65 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
+        expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
+             else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(self.config.num_experts) for shard_id,
+            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
                 if name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
-                if name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
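
A short walk-through of why the "mlp.experts" guard above has to run before the rename (the checkpoint key below is illustrative):

# An expert weight as it appears in the checkpoint.
name = "model.layers.0.mlp.experts.0.gate_proj.weight"
# What the stacked mapping would do to it if the guard were missing:
renamed = name.replace("gate_proj", "gate_up_proj")
assert "experts.0.gate_proj.weight" in name         # matched by expert_params_mapping
assert "experts.0.gate_proj.weight" not in renamed  # no longer matched after the rename,
# so the weight would never reach FusedMoE.weight_loader (and later stacked mappings
# could mangle the name further, as the comment in the diff notes).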

+ 189 - 1
aphrodite/quantization/fp8.py

@@ -8,6 +8,8 @@ from torch.nn.parameter import Parameter
 from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import (get_device_capability_stateless,
                                     print_warning_once)
+from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                 fused_moe)
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
@@ -70,7 +72,9 @@ class Fp8Config(QuantizationConfig):
 
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
-        if isinstance(layer, Attention):
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
         return None
 
@@ -267,6 +271,185 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])
 
 
+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports quantizing FP16/BF16 model checkpoints on the fly, with
+    dynamic activation scaling. The weight scaling factors will be initialized
+    after the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
+                       intermediate_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+
+        layer.process_after_load = True
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                  2,
+                                                  dtype=torch.float32),
+                                       requires_grad=False)
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 dtype=torch.float32),
+                                      requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        # If loading fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        #   process_weights_after_loading()).
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_scale, extra_weight_attrs)
+            set_weight_attrs(w2_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8.")
+
+            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                      dtype=torch.float32),
+                                           requires_grad=False)
+            layer.register_parameter("a13_scale", a13_scale)
+            set_weight_attrs(a13_scale, extra_weight_attrs)
+
+            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            layer.register_parameter("a2_scale", a2_scale)
+            set_weight_attrs(a2_scale, extra_weight_attrs)
+        else:
+            layer.a13_scale = None
+            layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if (not hasattr(layer, "process_after_load")
+                or not layer.process_after_load):
+            return
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(layer.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(layer.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_scale = torch.nn.Parameter(torch.ones(
+                layer.num_experts,
+                dtype=torch.float32,
+                device=w13_weight.device),
+                                                 requires_grad=False)
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], layer.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w2_weight.data[expert, :, :])
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            return
+
+        # If the checkpoint is fp8, handle the fact that the MoE kernels
+        # require a single activation scale and a single weight scale for
+        # w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.a13_scale is None or layer.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None.")
+                if (not all_close_1d(layer.a13_scale)
+                        or not all_close_1d(layer.a2_scale)):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. ")
+                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
+                                                     requires_grad=False)
+                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
+                                                    requires_grad=False)
+
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :] = per_tensor_quantize(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+
+            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
+                                                 requires_grad=False)
+            return
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+
+        return fused_moe(x,
+                         layer.w13_weight,
+                         layer.w2_weight,
+                         router_logits,
+                         top_k,
+                         renormalize=renormalize,
+                         inplace=True,
+                         use_fp8=True,
+                         w1_scale=layer.w13_scale,
+                         w2_scale=layer.w2_scale,
+                         a1_scale=layer.a13_scale,
+                         a2_scale=layer.a2_scale)
+
+
 class Fp8KVCacheMethod(QuantizeMethodBase):
     """Supports loading kv-cache scaling factors from FP8 checkpoints.
     """
@@ -318,3 +501,8 @@ def per_tensor_dequantize(
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
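
For the fp8-serialized branch above, the w13 rescaling amounts to: take the max of the two per-shard scales, dequantize each half with its own scale, and requantize with the shared max. A standalone sketch of that arithmetic follows; per_tensor_dequantize matches the helper shown in this hunk, while per_tensor_quantize is an assumption (divide by the scale, clamp to the float8_e4m3fn range, cast down), since its body is not part of this diff.

import torch

def per_tensor_dequantize(t: torch.Tensor, inv_scale: torch.Tensor) -> torch.Tensor:
    # Same as the helper above: upcast, then multiply by the scale.
    return t.to(torch.float16) * inv_scale

def per_tensor_quantize(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Assumed inverse: divide by the scale, clamp to the fp8 range, cast down.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (t / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

# One expert's merged w13 tensor with separate scales for the w1 and w3 halves.
shard_size, hidden_size = 4, 8                # toy shapes
w13 = torch.randn(2 * shard_size, hidden_size).to(torch.float8_e4m3fn)
w13_scale = torch.tensor([0.5, 2.0])          # per-shard scales (w1, w3)
max_scale = w13_scale.max()                   # the single scale the MoE kernel expects
for shard_id in range(2):
    rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
    dq = per_tensor_dequantize(w13[rows], w13_scale[shard_id])
    w13[rows] = per_tensor_quantize(dq, max_scale)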