moe: refactor DBRX experts to support FusedMoE (#1095)

AlpinDale 1 month ago
parent
commit
b65449b5ad
1 changed file with 54 additions and 69 deletions
+ 54 - 69
aphrodite/modeling/models/dbrx.py

@@ -8,9 +8,8 @@ from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
-from aphrodite.modeling.layers.fused_moe import fused_moe
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
@@ -21,7 +20,6 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -53,13 +51,7 @@ class DbrxRouter(nn.Module):
         return router_logits
 
 
-class DbrxExperts(nn.Module):
-    """A tensor-parallel MoE implementation for DBRX.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """
+class DbrxExperts(FusedMoE):
 
     def __init__(
         self,
@@ -67,49 +59,24 @@ class DbrxExperts(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=config.ffn_config.moe_num_experts,
+            top_k=config.ffn_config.moe_top_k,
+            hidden_size=config.d_model,
+            intermediate_size=config.ffn_config.ffn_hidden_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=get_tensor_model_parallel_world_size(),
+        )
+        self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.ffn_config.moe_num_experts
-        self.top_k = config.ffn_config.moe_top_k
         self.d_model = config.d_model
-        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
+        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                   self.tp_size)
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = DbrxRouter(config, self.params_dtype)
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.d_model,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.d_model,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-
-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
+    # Define a custom weight loader for the DBRX expert weights.
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str):
         tp_rank = get_tensor_model_parallel_rank()
@@ -139,26 +106,40 @@ class DbrxExperts(nn.Module):
             ).transpose(1, 2)
             param_data[:] = loaded_weight[:, :, shard]
 
+
+class DbrxMoE(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+
+        self.experts = DbrxExperts(config=config,
+                                   quant_config=quant_config,
+                                   params_dtype=self.params_dtype)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class DbrxAttention(nn.Module):
@@ -287,7 +268,7 @@ class DbrxBlock(nn.Module):
         super().__init__()
         self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                      quant_config)
-        self.ffn = DbrxExperts(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config)
 
     def forward(
         self,
@@ -361,6 +342,9 @@ class DbrxForCausalLM(nn.Module):
     ):
         super().__init__()
         self.config = config
+        if config.tie_word_embeddings:
+            raise ValueError(
+                "tie_word_embeddings is not supported for Dbrx models.")
         self.quant_config = quant_config
         self.unpadded_vocab_size = config.vocab_size
         self.transformer = DbrxModel(config, cache_config, quant_config)
@@ -405,9 +389,10 @@ class DbrxForCausalLM(nn.Module):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
         expert_params_mapping = [(
-            "ws" if weight_name in ["w1", "v1"] else "w2s",
-            f"experts.mlp.{weight_name}",
+            "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
+            f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
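
Taken together, the refactor splits the old monolithic DbrxExperts into a thin FusedMoE subclass, which keeps only the DBRX-specific weight loader, plus a new DbrxMoE wrapper that owns the router and delegates the forward pass to the fused experts. As the DbrxMoE docstring in the diff notes, expert weights stay sharded across ranks and a fused kernel handles the forward pass, with the cross-rank reduction now done inside FusedMoE (reduce_results=True) instead of an explicit tensor_model_parallel_all_reduce. The sketch below mirrors that wiring with simplified stand-ins: FusedMoE here is a stub for illustration only, and the constructor arguments shown are just the ones visible in this diff, not the library's full API.

import torch
import torch.nn as nn


class FusedMoE(nn.Module):
    """Stub standing in for aphrodite.modeling.layers.fused_moe.FusedMoE.

    Only the parts used by the diff above are modeled: the constructor
    arguments DbrxExperts passes and a forward(hidden_states, router_logits)
    entry point.
    """

    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int, **kwargs):
        super().__init__()
        self.top_k = top_k
        # Placeholder expert weights; the real layer shards these across
        # tensor-parallel ranks and runs a fused kernel over them.
        self.w13_weight = nn.Parameter(
            torch.empty(num_experts, 2 * intermediate_size, hidden_size))
        self.w2_weight = nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size))

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        # The real implementation picks the top-k experts per token, runs
        # their MLPs with a fused kernel, and reduces the result across ranks.
        return hidden_states


class DbrxRouter(nn.Module):
    """Simplified router: one linear layer producing per-expert logits."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.layer = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.layer(hidden_states)


class DbrxMoE(nn.Module):
    """Router plus fused experts, mirroring the structure this commit adds."""

    def __init__(self, d_model: int, num_experts: int, top_k: int,
                 ffn_hidden_size: int):
        super().__init__()
        self.d_model = d_model
        self.router = DbrxRouter(d_model, num_experts)
        self.experts = FusedMoE(num_experts, top_k, d_model, ffn_hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        router_logits = self.router(hidden_states)  # (num_tokens, n_experts)
        out = self.experts(hidden_states, router_logits)
        return out.view(orig_shape)


if __name__ == "__main__":
    moe = DbrxMoE(d_model=64, num_experts=4, top_k=2, ffn_hidden_size=128)
    x = torch.randn(2, 8, 64)  # (batch, seq_len, d_model)
    print(moe(x).shape)  # torch.Size([2, 8, 64])

Subclassing FusedMoE lets DbrxExperts reuse the shared MoE machinery (parameter sharding, fused dispatch, result reduction) while keeping only the DBRX checkpoint mapping in its weight_loader, which is why load_weights now targets w13_weight and w2_weight instead of the old ws and w2s parameters.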