Ver código fonte

models: add support for Phi3 MoE

AlpinDale 2 meses atrás
pai
commit
201db10f02

+ 14 - 5
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -2,7 +2,7 @@
 import functools
 import json
 import os
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import torch
 import triton
@@ -444,7 +444,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      rand_perm1: torch.Tensor,
                      rand_perm2: torch.Tensor,
                      topk: int,
-                     renormalize: bool,
+                     custom_routing_function: Optional[Callable] = None,
+                     renormalize: bool = True,
                      override_config: Optional[Dict[str, Any]] = None,
                      use_fp8: bool = False,
                      w1_scale: Optional[torch.Tensor] = None,
@@ -495,8 +496,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     E = w1.shape[0]
     N = w2.shape[1] * 16
 
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize)
 
     get_config_func = functools.partial(try_get_optimal_moe_config,
                                         w1.shape,
@@ -691,6 +696,7 @@ def fused_moe(
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -738,9 +744,12 @@ def fused_moe(
         topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
                                               topk, renormalize,
                                               num_expert_group, topk_group)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                             renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize)
 
     return fused_experts(hidden_states,
                          w1,

+ 56 - 34
aphrodite/modeling/layers/fused_moe/layer.py

@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from enum import Enum
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -59,15 +59,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool,
-              use_grouped_topk: bool,
-              topk_group: Optional[int] = None,
-              num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+            self,
+            layer: torch.nn.Module,
+            x: torch.Tensor,
+            router_logits: torch.Tensor,
+            top_k: int,
+            renormalize: bool,
+            use_grouped_topk: bool,
+            topk_group: Optional[int] = None,
+            num_expert_group: Optional[int] = None,
+            custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         return self.forward(x=x,
                             layer=layer,
@@ -76,17 +79,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                             renormalize=renormalize,
                             use_grouped_topk=use_grouped_topk,
                             topk_group=topk_group,
-                            num_expert_group=num_expert_group)
-
-    def forward_cuda(self,
-                     layer: torch.nn.Module,
-                     x: torch.Tensor,
-                     use_grouped_topk: bool,
-                     top_k: int,
-                     router_logits: torch.Tensor,
-                     renormalize: bool,
-                     topk_group: Optional[int] = None,
-                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+                            num_expert_group=num_expert_group,
+                            custom_routing_function=custom_routing_function)
+
+    def forward_cuda(
+            self,
+            layer: torch.nn.Module,
+            x: torch.Tensor,
+            use_grouped_topk: bool,
+            top_k: int,
+            router_logits: torch.Tensor,
+            renormalize: bool,
+            topk_group: Optional[int] = None,
+            num_expert_group: Optional[int] = None,
+            custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         from aphrodite.modeling.layers.fused_moe.fused_moe import fused_experts
 
@@ -97,7 +104,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group)
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
 
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
@@ -110,20 +118,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(self,
-                    layer: torch.nn.Module,
-                    x: torch.Tensor,
-                    use_grouped_topk: bool,
-                    top_k: int,
-                    router_logits: torch.Tensor,
-                    renormalize: bool,
-                    topk_group: Optional[int] = None,
-                    num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def forward_tpu(
+            self,
+            layer: torch.nn.Module,
+            x: torch.Tensor,
+            use_grouped_topk: bool,
+            top_k: int,
+            router_logits: torch.Tensor,
+            renormalize: bool,
+            topk_group: Optional[int] = None,
+            num_expert_group: Optional[int] = None,
+            custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
+        assert custom_routing_function is None
         return fused_moe(hidden_states=x,
                          w1=layer.w13_weight,
                          w2=layer.w2_weight,
@@ -168,6 +180,7 @@ class FusedMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
     ):
         super().__init__()
 
@@ -186,6 +199,7 @@ class FusedMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -386,7 +400,8 @@ class FusedMoE(torch.nn.Module):
                        use_grouped_topk: bool,
                        renormalize: bool,
                        topk_group: Optional[int] = None,
-                       num_expert_group: Optional[int] = None):
+                       num_expert_group: Optional[int] = None,
+                       custom_routing_function: Optional[Callable] = None):
         from aphrodite.modeling.layers.fused_moe.fused_moe import (
             fused_topk, grouped_topk)
 
@@ -401,11 +416,17 @@ class FusedMoE(torch.nn.Module):
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group)
-        else:
+        elif custom_routing_function is None:
             topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                 gating_output=router_logits,
                                                 topk=top_k,
                                                 renormalize=renormalize)
+        else:
+            topk_weights, topk_ids = custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize)
 
         return topk_weights, topk_ids
 
@@ -422,7 +443,8 @@ class FusedMoE(torch.nn.Module):
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group)
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(

+ 14 - 11
aphrodite/modeling/layers/rotary_embedding.py

@@ -508,8 +508,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
-        short_mscale: float = 1.0,
-        long_mscale: float = 1.0,
+        short_mscale: Optional[float] = None,
+        long_mscale: Optional[float] = None,
     ):
         super().__init__()
 
@@ -528,18 +528,21 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         self.base = base
         self.short_factor = short_factor
         self.long_factor = long_factor
-        self.short_mscale = short_mscale
-        self.long_mscale = long_mscale
-
-        scale = (self.max_position_embeddings /
-                 self.original_max_position_embeddings)
 
+        scale = self.max_position_embeddings / \
+            self.original_max_position_embeddings
         if scale <= 1.0:
-            self.scaling_factor = 1.0
+            scaling_factor = 1.0
         else:
-            self.scaling_factor = math.sqrt(
+            scaling_factor = math.sqrt(
                 1 + math.log(scale) /
                 math.log(self.original_max_position_embeddings))
+        if short_mscale is None:
+            short_mscale = scaling_factor
+        if long_mscale is None:
+            long_mscale = scaling_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
@@ -576,8 +579,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         inv_freq = self._compute_inv_freq(rescale_factors)
         t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos() * mscale * self.scaling_factor
-        sin = freqs.sin() * mscale * self.scaling_factor
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -47,6 +47,7 @@ _GENERATION_MODELS = {
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
+    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),

+ 619 - 0
aphrodite/modeling/models/phimoe.py

@@ -0,0 +1,619 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only PhiMoE model."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.sequence import IntermediateTensors
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.fused_moe import FusedMoE
+from aphrodite.modeling.layers.linear import (QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+
+from .interfaces import SupportsLoRA
+
+
+class PhiMoEConfig(PretrainedConfig):
+
+    model_type = "phimoe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=1e6,
+        sliding_window=None,
+        attention_dropout=0.0,
+        num_experts_per_tok=2,
+        num_local_experts=16,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
+        attention_bias=False,
+        lm_head_bias=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.attention_bias = attention_bias
+        self.lm_head_bias = lm_head_bias
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+class mp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        scores: torch.Tensor,
+        multiplier: torch.Tensor,
+        selected_experts: torch.Tensor,
+        masked_gates: torch.Tensor,
+        mask_for_one: torch.Tensor,
+    ):
+        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
+        return multiplier * mask_for_one
+
+    @staticmethod
+    def backward(
+        ctx,
+        grad_at_output: torch.Tensor,
+    ):
+        multiplier, selected_experts, masked_gates = ctx.saved_tensors
+
+        grad_at_output = grad_at_output * multiplier
+
+        grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
+        grad_at_scores_expaned.scatter_add_(
+            dim=-1,
+            index=selected_experts,
+            src=grad_at_output,
+        )
+
+        return (
+            grad_at_scores_expaned,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def sparsemixer(scores, jitter_eps=0.01):
+    ################ first expert ################
+
+    with torch.no_grad():
+        # compute mask for sparsity
+        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
+        factor = scores.abs().clamp(min=mask_logits_threshold)
+        mask_logits_threshold = (
+            (mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
+
+    # apply mask
+    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
+    selected_experts = max_ind
+
+    # compute scores for gradients
+    masked_gates = torch.softmax(masked_gates, dim=-1)
+    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
+
+    multiplier = multiplier_o
+
+    # masked out first expert
+    masked_scores = torch.scatter(
+        scores,
+        -1,
+        selected_experts,
+        float("-inf"),
+    )
+    with torch.no_grad():
+        # compute mask for sparsity
+        mask_logits_threshold, max_ind = masked_scores.max(dim=-1,
+                                                           keepdim=True)
+        factor = scores.abs().clamp(min=mask_logits_threshold)
+        mask_logits_threshold = (
+            (mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
+
+    # apply mask
+    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold,
+                                                  float("-inf"))
+    selected_experts_top2 = max_ind
+    # compute scores for gradients
+    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
+    multiplier_top2 = masked_gates_top2.gather(dim=-1,
+                                               index=selected_experts_top2)
+
+    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
+    selected_experts = torch.concat((selected_experts, selected_experts_top2),
+                                    dim=-1)
+
+    return (
+        multiplier,
+        selected_experts,
+    )
+
+
+def phimoe_routing_function(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert topk == 2, "Only top-2 routing is supported"
+    assert renormalize is False, "Renormalization is not supported"
+
+    topk_weights, topk_ids = sparsemixer(gating_output)
+    return topk_weights, topk_ids
+
+
+class PhiMoE(nn.Module):
+    """A tensor-parallel MoE implementation for PhiMoE that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size,
+            num_experts,
+            bias=False,
+            params_dtype=params_dtype,
+            quant_config=None,
+        )
+
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            custom_routing_function=phimoe_routing_function)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
+
+
+class PhiMoEAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[dict] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=True,
+            quant_config=None,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=True,
+            quant_config=None,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+            rope_scaling=self.rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class PhiMoEDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiMoEConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = PhiMoEAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=config.rope_scaling,
+        )
+        self.block_sparse_moe = PhiMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps,
+                                            elementwise_affine=True)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps,
+                                                     elementwise_affine=True)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        # Self Attention
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = hidden_states + residual
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.block_sparse_moe(hidden_states)
+
+        hidden_states = hidden_states + residual
+        return hidden_states, residual
+
+
+class PhiMoEModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiMoEConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.layers = nn.ModuleList([
+            PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size,
+                                 eps=config.rms_norm_eps,
+                                 elementwise_affine=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(positions, hidden_states,
+                                            kv_caches[i], attn_metadata,
+                                            residual)
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
+    fall_back_to_pt_during_load = False
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        config: PhiMoEConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.model = PhiMoEModel(config,
+                                 cache_config,
+                                 quant_config,
+                                 lora_config=lora_config)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=(
+                DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size),
+            quant_config=None,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

+ 14 - 10
aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -1,6 +1,6 @@
 import enum
 from enum import Enum
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 import torch
 
@@ -255,15 +255,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         )
         replace_tensor("w2_weight_scale", marlin_w2_scales)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
 
         from aphrodite.modeling.layers.fused_moe.fused_moe import (
             fused_marlin_moe)
@@ -277,6 +280,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
                                 layer.w13_g_idx_sort_indices,
                                 layer.w2_g_idx_sort_indices,
                                 top_k,
+                                custom_routing_function,
                                 renormalize=renormalize,
                                 w1_scale=layer.w13_weight_scale,
                                 w2_scale=layer.w2_weight_scale)

+ 15 - 11
aphrodite/quantization/experts_int8.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 
@@ -96,15 +96,18 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
                                       requires_grad=False)
         layer.register_parameter("w2_scale", w2_scale)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
         from aphrodite.modeling.layers.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -114,7 +117,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group)
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
 
         return fused_experts(x,
                              layer.w13_weight,

+ 15 - 11
aphrodite/quantization/fp8.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from loguru import logger
@@ -465,15 +465,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                         requires_grad=False)
             return
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool,
-              use_grouped_topk: bool,
-              topk_group: Optional[int] = None,
-              num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
 
         from aphrodite.modeling.layers.fused_moe import fused_experts
 
@@ -484,7 +487,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group)
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
 
         return fused_experts(x,
                              layer.w13_weight,