from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch

from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling._custom_op import CustomOp
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
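

# Granularity note (illustrative, derived from the loaders further below,
# not part of the original code):
#   TENSOR  -> one scale per logical weight per expert; w13 scales are
#              stored as param[expert_id][0] (w1) and param[expert_id][1]
#              (w3), see _load_per_tensor_weight_scale().
#   CHANNEL -> per-output-channel scales; w13 scales are sharded and loaded
#              like the weights, see _load_per_channel_weight_scale().
#   GROUP   -> per-group scales; sharded with the same helpers as the model
#              weights, see _load_model_weight_or_group_weight_scale().
# The scale tensors themselves are created by the quantization method.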


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              router_logits: torch.Tensor, top_k: int, renormalize: bool,
              use_grouped_topk: bool) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
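
    # Shape sketch (illustrative, not part of the original code): with
    # num_experts=8, hidden_size=4096 and a per-partition intermediate_size
    # of 14336, create_weights() above registers
    #   layer.w13_weight: [8, 2 * 14336, 4096]  (gate_proj and up_proj fused)
    #   layer.w2_weight:  [8, 4096, 14336]      (down_proj)
    # The first half of dim 1 of w13_weight holds w1 (gate_proj) and the
    # second half holds w3 (up_proj), which is how _load_w13() fills it.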

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        return self.forward(x=x,
                            layer=layer,
                            router_logits=router_logits,
                            top_k=top_k,
                            renormalize=renormalize,
                            use_grouped_topk=use_grouped_topk,
                            topk_group=topk_group,
                            num_expert_group=num_expert_group,
                            custom_routing_function=custom_routing_function)

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w2=layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True)

    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError(
            "The CPU backend currently does not support MoE.")

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        return fused_moe(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk=top_k,
                         gating_output=router_logits,
                         renormalize=renormalize)
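

# Dispatch note (illustrative): UnquantizedFusedMoEMethod mixes in CustomOp,
# so the self.forward(...) call inside apply() is routed to the
# platform-specific implementation, roughly:
#
#     method = UnquantizedFusedMoEMethod()
#     out = method.apply(layer, x, router_logits, top_k=2,
#                        renormalize=True, use_grouped_topk=False)
#     # -> forward_cuda() on GPU, forward_tpu() on TPU,
#     #    forward_cpu() (NotImplementedError) on CPU.
#
# The exact dispatch mechanism lives in aphrodite.modeling._custom_op and is
# assumed here, not shown in this file.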


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1 for gate_proj, w3 for up_proj, and w2 for
    down_proj. We copy that naming convention here and handle any remapping
    in the load_weights function of each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all-reduce the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe
            kernel.
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader)

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
                                      expert_id: int):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int):
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(shard_id=shard_id,
                          shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                       shard_dim: int, shard_id: str,
                                       loaded_weight: torch.Tensor,
                                       tp_rank: int):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel", so tp sharding on input_dim.
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)
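
    # Worked sharding example (illustrative, not part of the original code):
    # with tp_size=2, hidden_size=4096 and full intermediate_size=28672, the
    # per-partition intermediate size is 14336. For tp_rank=1:
    #   _load_w13 ("w1"): expert_data is w13_weight[e] of shape
    #       [2 * 14336, 4096]; shard_size = 14336, so rows 14336:28672 of the
    #       full gate_proj weight are copied into expert_data[0:14336, :].
    #   _load_w2: expert_data is w2_weight[e] of shape [4096, 14336];
    #       shard_size = 14336 along dim 1, so columns 14336:28672 of the
    #       full down_proj weight are copied into the whole expert_data.
    # The checkpoint tensors are full-size; only this rank's slice is kept.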

    def _load_single_value(self, param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor, expert_id: int):
        param_data = param.data
        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:
        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError("shard_id must be one of ('w1', 'w2', 'w3') "
                             f"but got {shard_id}.")

        WEIGHT_SCALE_SUPPORTED = [
            e.value for e in FusedMoeWeightScaleSupported
        ]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This is the dimension along which
        # intermediate_size is sharded.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: whether or not the parameter is transposed on disk.
        # If transposed, the loaded weight will be transposed and the dim
        # to shard the loaded weight will be flipped.
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            loaded_weight = loaded_weight.t().contiguous()
            # For a 2D expert tensor, ~0 == -1 and ~1 == -2, i.e. this
            # selects the other dimension.
            shard_dim = ~shard_dim

        # Case weight_scales
        if "weight_scale" in weight_name:
            # load the weight scaling based on the quantization scheme
            # supported weight scales can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return

        # Case weight shapes
        if "weight_shape" in weight_name:
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case input scale
        if "input_scale" in weight_name:
            # Note: input_scale loading is only supported for fp8
            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank)
            return
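
    # Loader routing sketch (illustrative; the checkpoint key format is
    # assumed here and varies by model):
    #
    #     moe.weight_loader(param=moe.w13_weight,
    #                       loaded_weight=tensor,
    #                       weight_name="experts.3.gate_proj.weight",
    #                       shard_id="w1",
    #                       expert_id=3)
    #
    # A name like this falls through to the final "weight" branch above and
    # shards the tensor along dim 0, while names containing "weight_scale",
    # "weight_shape" or "input_scale" take the earlier branches.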

    @staticmethod
    def select_experts(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       top_k: int,
                       use_grouped_topk: bool,
                       renormalize: bool,
                       topk_group: Optional[int] = None,
                       num_expert_group: Optional[int] = None,
                       custom_routing_function: Optional[Callable] = None):
        from aphrodite.modeling.layers.fused_moe.fused_moe import (
            fused_topk, grouped_topk)

        # DeepSeek-V2 uses grouped_topk
        if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group)
        elif custom_routing_function is None:
            topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                gating_output=router_logits,
                                                topk=top_k,
                                                renormalize=renormalize)
        else:
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize)

        return topk_weights, topk_ids
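
    # custom_routing_function sketch (illustrative): a custom router must
    # accept the keyword arguments used above and return
    # (topk_weights, topk_ids), both of shape [num_tokens, topk]:
    #
    #     def my_router(*, hidden_states, gating_output, topk, renormalize):
    #         weights = torch.softmax(gating_output, dim=-1)
    #         topk_weights, topk_ids = torch.topk(weights, topk, dim=-1)
    #         if renormalize:
    #             topk_weights = topk_weights / topk_weights.sum(
    #                 dim=-1, keepdim=True)
    #         return topk_weights, topk_ids
    #
    # "my_router" is a made-up name; dtype/device handling is left out.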

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts) for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]
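
    # Mapping example (illustrative): for a Mixtral-style checkpoint whose
    # gate/down/up projections are named "w1"/"w2"/"w3" and num_experts=2,
    # make_expert_params_mapping("w1", "w2", "w3", 2) yields entries like
    #
    #     ("experts.w13_", "experts.0.w1.", 0, "w1"),
    #     ("experts.w2_",  "experts.0.w2.", 0, "w2"),
    #     ("experts.w13_", "experts.0.w3.", 0, "w3"),
    #     ("experts.w13_", "experts.1.w1.", 1, "w1"),
    #     ...
    #
    # i.e. (param_name prefix, checkpoint weight_name prefix, expert_id,
    # shard_id), which the model's load_weights() uses to route each
    # checkpoint tensor to weight_loader() with the right shard_id.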

    def _load_fp8_scale(self, param: torch.nn.Parameter,
                        loaded_weight: torch.Tensor, weight_name: str,
                        shard_id: str, expert_id: int) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
|