from abc import abstractmethod from enum import Enum from typing import Callable, List, Optional, Tuple import torch from aphrodite.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from aphrodite.modeling._custom_op import CustomOp from aphrodite.modeling.utils import set_weight_attrs from aphrodite.quantization.base_config import (QuantizationConfig, QuantizeMethodBase) class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" GROUP = "group" class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError @abstractmethod def apply(self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool) -> torch.Tensor: raise NotImplementedError class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter(torch.empty(num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size, dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None ) -> torch.Tensor: return self.forward(x=x, layer=layer, router_logits=router_logits, top_k=top_k, renormalize=renormalize, use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function) def forward_cuda( self, layer: torch.nn.Module, x: torch.Tensor, use_grouped_topk: bool, top_k: int, router_logits: torch.Tensor, renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None ) -> torch.Tensor: from aphrodite.modeling.layers.fused_moe.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function) return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True) def forward_cpu(self, *args, **kwargs): raise NotImplementedError( "The CPU backend currently does not support MoE.") def forward_tpu( self, layer: torch.nn.Module, x: torch.Tensor, use_grouped_topk: bool, top_k: int, router_logits: torch.Tensor, renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None ) -> torch.Tensor: from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk assert num_expert_group is None assert topk_group is None assert custom_routing_function is None return fused_moe(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk=top_k, gating_output=router_logits, renormalize=renormalize) class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. This layer contains both MergedColumnParallel weights (gate_up_proj / w13) and RowParallelLinear weights (down_proj/ w2). Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We copy that naming convention here and handle any remapping in the load_weights function in each model implementation. Args: num_experts: Number of experts in the model top_k: Number of experts selected for each token hidden_size: Input hidden state size of the transformer intermediate_size: Intermediate size of the experts params_dtype: Data type for the parameters. reduce_results: Whether to all all_reduce on the output of the layer renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. """ def __init__( self, num_experts: int, top_k: int, hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, renormalize: bool = True, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, ): super().__init__() if params_dtype is None: params_dtype = torch.get_default_dtype() self.tp_size = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) self.top_k = top_k self.num_experts = num_experts self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize self.use_grouped_topk = use_grouped_topk if self.use_grouped_topk: assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group self.custom_routing_function = custom_routing_function if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( UnquantizedFusedMoEMethod()) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None self.quant_method.create_weights( layer=self, num_experts=num_experts, hidden_size=hidden_size, intermediate_size=self.intermediate_size_per_partition, params_dtype=params_dtype, weight_loader=self.weight_loader) def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int): param_data = param.data # for per tensor weight quantization if shard_id in ("w1", "w3"): # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. idx = 0 if shard_id == "w1" else 1 param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) elif shard_id == "w2": param_data[expert_id] = loaded_weight def _load_model_weight_or_group_weight_scale(self, shard_dim: int, expert_data: torch.Tensor, shard_id: str, loaded_weight: torch.tensor, tp_rank: int): # Load grouped weight scales for group quantization # or model weights if shard_id == "w2": self._load_w2(shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) elif shard_id in ("w1", "w3"): self._load_w13(shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.tensor, tp_rank: int): # for per channel weight quantization if shard_id == "w2": expert_data.copy_(loaded_weight) elif shard_id in ("w1", "w3"): self._load_w13(shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.tensor, tp_rank: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": expert_data = expert_data.narrow(shard_dim, 0, shard_size) # w3, up_proj: Load into second logical weight of w13. else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.tensor, tp_rank: int): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) def _load_single_value(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int): param_data = param.data # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") WEIGHT_SCALE_SUPPORTED = [ e.value for e in FusedMoeWeightScaleSupported ] # Fetch the dim to shard the parameter/loaded weight # based on the shard id. This will be whatever # dimension intermediate_size is used. SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() # is_transposed: whether or not the parameter is transposed on disk # If transposed, the loaded weight will be transposed and the dim # to shard the loaded weight will be flipped. is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim # Case weight_scales if "weight_scale" in weight_name: # load the weight scaling based on the quantization scheme # supported weight scales can be found in # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case quant_method = getattr(param, "quant_method", None) if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: self._load_per_channel_weight_scale( shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: self._load_per_tensor_weight_scale(shard_id=shard_id, param=param, loaded_weight=loaded_weight, expert_id=expert_id) else: raise ValueError( f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") return if "weight_shape" in weight_name: self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) return # Case input scale if "input_scale" in weight_name: # Note: input_scale loading is only supported for fp8 if param.data[expert_id] != 1 and (param.data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param.data[expert_id]} " f"vs. {loaded_weight}") self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) return # Case model weights if "weight" in weight_name: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=tp_rank) return @staticmethod def select_experts(hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, use_grouped_topk: bool, renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None): from aphrodite.modeling.layers.fused_moe.fused_moe import ( fused_topk, grouped_topk) # DeekSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None topk_weights, topk_ids = grouped_topk( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize, num_expert_group=num_expert_group, topk_group=topk_group) elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) else: topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) return topk_weights, topk_ids def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) return final_hidden_states @classmethod def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, num_experts: int) -> List[Tuple[str, str, int, str]]: return [ # (param_name, weight_name, expert_id, shard_id) ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) for expert_id in range(num_experts) for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), ("w3", ckpt_up_proj_name), ] ] def _load_fp8_scale(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: param_data = param.data # Input scales can be loaded directly and should be equal. if "input_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param_data[expert_id]} " f"vs. {loaded_weight}") param_data[expert_id] = loaded_weight # Weight scales elif "weight_scale" in weight_name: # If we are in merged column case (gate_up_proj) if shard_id in ("w1", "w3"): # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. idx = 0 if shard_id == "w1" else 1 param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: param_data[expert_id] = loaded_weight