@@ -31,11 +31,10 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
+from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.fused_moe import fused_experts, grouped_topk
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
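The import change is the heart of the refactor: the per-expert `DeepseekV2MLP` modules driven by the free functions `grouped_topk` and `fused_experts` give way to the `FusedMoE` layer, which owns the packed expert weights and is called directly with the hidden states and router logits (see the hunks below). As a rough mental model of what such a fused layer computes per token, here is a naive, unfused sketch; the class name, initialization, and shapes are assumptions for illustration, and it always renormalizes the top-k weights, unlike `FusedMoE`, which takes `renormalize` from the config:

```python
# Illustrative only: a naive stand-in with the same call contract as the
# FusedMoE layer introduced below -- forward(hidden_states, router_logits)
# returns the combined expert output. Not aphrodite's implementation.
import torch
import torch.nn.functional as F


class NaiveMoESketch(torch.nn.Module):

    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int):
        super().__init__()
        self.top_k = top_k
        # Per-expert packed gate+up projection and down projection.
        self.w13 = torch.nn.Parameter(
            torch.randn(num_experts, 2 * intermediate_size, hidden_size) * 0.02)
        self.w2 = torch.nn.Parameter(
            torch.randn(num_experts, hidden_size, intermediate_size) * 0.02)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        # Route each token to its top-k experts and renormalize the weights.
        probs = router_logits.softmax(dim=-1)
        topk_w, topk_ids = probs.topk(self.top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
        out = torch.zeros_like(hidden_states)
        for k in range(self.top_k):
            w13 = self.w13[topk_ids[:, k]]  # (tokens, 2*inter, hidden)
            w2 = self.w2[topk_ids[:, k]]    # (tokens, hidden, inter)
            gate_up = torch.bmm(w13, hidden_states.unsqueeze(-1)).squeeze(-1)
            gate, up = gate_up.chunk(2, dim=-1)
            act = F.silu(gate) * up         # same nonlinearity as SiluAndMul
            expert_out = torch.bmm(w2, act.unsqueeze(-1)).squeeze(-1)
            out = out + topk_w[:, k:k + 1] * expert_out
        return out
```

Calling it mirrors the rewritten forward pass below: `out = moe(hidden_states, router_logits)`.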
@@ -91,32 +90,34 @@ class DeepseekV2MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.n_routed_experts
-        self.top_k = config.num_experts_per_tok
         self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > self.n_routed_experts:
+        self.n_shared_experts = config.n_shared_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        if self.tp_size > config.n_routed_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            DeepseekV2MLP(hidden_size=config.hidden_size,
-                          intermediate_size=config.moe_intermediate_size,
-                          hidden_act=config.hidden_act,
-                          quant_config=quant_config,
-                          reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+                f"the number of experts {config.n_routed_experts}.")
+
+        if config.hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+                             "Only silu is supported for now.")
+
+        self.experts = FusedMoE(num_experts=config.n_routed_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config,
+                                use_grouped_topk=True,
+                                num_expert_group=config.n_group,
+                                topk_group=config.topk_group)

         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.n_routed_experts,
                                      bias=False,
                                      quant_config=None)
-
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
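In the new constructor, routing is delegated to the layer via `use_grouped_topk=True` with `num_expert_group` and `topk_group` taken from the config, replacing the explicit `grouped_topk` call in the old forward pass. A minimal sketch of that group-limited selection follows; it is illustrative only, not aphrodite's fused implementation, and assumes the expert count divides evenly into groups:

```python
import torch


def grouped_topk_sketch(router_logits: torch.Tensor, top_k: int,
                        num_expert_group: int, topk_group: int,
                        renormalize: bool):
    """Sketch of group-limited expert selection in the DeepSeek-V2 style."""
    scores = router_logits.softmax(dim=-1)              # (tokens, experts)
    num_tokens, num_experts = scores.shape
    # Score each group by its best expert, keep the topk_group best groups.
    group_scores = scores.view(num_tokens, num_expert_group,
                               -1).max(dim=-1).values   # (tokens, groups)
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    # Mask out experts from the discarded groups, then take the usual top-k.
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group,
        num_experts // num_expert_group).reshape(num_tokens, num_experts)
    masked_scores = scores.masked_fill(expert_mask == 0, 0.0)
    topk_weights, topk_ids = masked_scores.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```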
@@ -128,50 +129,21 @@ class DeepseekV2MoE(nn.Module):
                 reduce_results=False,
             )

-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        if self.config.n_shared_experts is not None:
+        if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
-            num_expert_group=self.config.n_group,
-            topk_group=self.config.topk_group)
-        final_hidden_states = fused_experts(
-            hidden_states,
-            self.w1,
-            self.w2,
-            topk_weights,
-            topk_ids,
-            inplace=True) * self.routed_scaling_factor
-        if self.config.n_shared_experts is not None:
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+        if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)

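Two details of the rewritten forward pass are easy to miss: only the routed output is scaled by `routed_scaling_factor`, and the all-reduce now runs once, at the end, and only when `tp_size > 1`. Since both the routed experts and the shared experts are constructed with `reduce_results=False`, each rank holds partial sums, and reducing their combination once is equivalent to reducing each separately. A toy check of that equivalence, with plain tensors standing in for per-rank partials (not aphrodite's distributed ops):

```python
import torch

# Two "ranks", each holding partial sums from the routed experts and the
# shared experts. The tensors and the scale are made up for illustration.
torch.manual_seed(0)
routed_partials = [torch.randn(4, 8) for _ in range(2)]
shared_partials = [torch.randn(4, 8) for _ in range(2)]
scale = 2.5  # stands in for routed_scaling_factor

# What forward() now does: scale the routed part, add the shared part on each
# rank, then all-reduce once (modelled here as a plain sum over ranks).
single_reduce = sum(scale * r + s
                    for r, s in zip(routed_partials, shared_partials))

# Reducing the routed and shared parts separately gives the same result,
# which is why one deferred all-reduce is enough.
double_reduce = scale * sum(routed_partials) + sum(shared_partials)

assert torch.allclose(single_reduce, double_reduce)
```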
@@ -504,34 +476,58 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_experts." in name)
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
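The comment block in the stacked-parameters loop deserves a concrete example: once `gate_proj` in an expert checkpoint key has been rewritten to `gate_up_proj`, the name still contains `up_proj`, so a later substitution of `up_proj` corrupts it into a nonexistent `gate_gate_up_proj` parameter. That is why the `mlp.experts.` skip happens before `name.replace`. A tiny standalone reproduction of the string behavior, with a made-up checkpoint key:

```python
# The two substitutions mirror the stacked mapping entries
# ("gate_up_proj", "gate_proj", 0) and ("gate_up_proj", "up_proj", 1).
name = "model.layers.1.mlp.experts.0.gate_proj.weight"  # hypothetical ckpt key

step1 = name.replace("gate_proj", "gate_up_proj")
print(step1)  # model.layers.1.mlp.experts.0.gate_up_proj.weight

# The renamed key still contains "up_proj", so a later match on that rule
# rewrites it a second time:
step2 = step1.replace("up_proj", "gate_up_proj")
print(step2)  # model.layers.1.mlp.experts.0.gate_gate_up_proj.weight
              # -> not in params_dict, so loading would fail; hence the
              #    early `mlp.experts.` skip before the rename.
```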