@@ -48,7 +48,8 @@ from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.quantization.fp8 import Fp8Config, scaled_fp8_quant
+from aphrodite.quantization.fp8 import (Fp8Config, per_tensor_dequantize,
+                                        per_tensor_quantize, scaled_fp8_quant)
 
 
 class MixtralMoE(nn.Module):
@@ -96,16 +97,16 @@ class MixtralMoE(nn.Module):
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
-        self.w13_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        dtype=params_dtype))
-        self.w2_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        dtype=params_dtype))
+        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
+                                                   2 * self.intermediate_size,
+                                                   self.hidden_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
+                                                  self.hidden_size,
+                                                  self.intermediate_size,
+                                                  dtype=params_dtype),
+                                      requires_grad=False)
 
         set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
@@ -122,7 +123,10 @@ class MixtralMoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
             self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     2,
                                                      dtype=torch.float32),
                                           requires_grad=False)
             self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
@@ -146,11 +150,11 @@ class MixtralMoE(nn.Module):
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
                         "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.zeros(
+                self.a13_scale = nn.Parameter(torch.ones(
                     self.num_total_experts, dtype=torch.float32),
                                               requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
+                self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                        dtype=torch.float32),
                                              requires_grad=False)
 
                 set_weight_attrs(self.a13_scale, {
@@ -173,8 +177,22 @@ class MixtralMoE(nn.Module):
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
+
+        # Loading scales
+        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "act_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
+        elif "weight_scale" in weight_name:
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            assert "w1" in weight_name or "w3" in weight_name
+            shard_id = 0 if "w1" in weight_name else 1
+            param_data[expert_id][shard_id] = loaded_weight
 
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
@@ -187,6 +205,12 @@ class MixtralMoE(nn.Module):
                                           dtype=torch.float8_e4m3fn)
             w2_weight = torch.empty_like(self.w2_weight.data,
                                          dtype=torch.float8_e4m3fn)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
             for expert in range(self.num_total_experts):
                 w13_weight[
                     expert, :, :], self.w13_scale[expert] = scaled_fp8_quant(
@@ -197,25 +221,44 @@ class MixtralMoE(nn.Module):
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None.")
+        else:
+            # If checkpoint is fp8 + static, cleanup act_scales.
+            # Since state_dict has an act_scale per expert but our kernels
+            # are passed one act_scale shared across all experts.
+            if self.quant_config.activation_scheme == "static":
+                if self.a13_scale is None or self.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None.")
 
-            if (not all_close_1d(self.a13_scale)
-                    or not all_close_1d(self.a2_scale)):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. ")
+                if (not all_close_1d(self.a13_scale)
+                        or not all_close_1d(self.a2_scale)):
+                    print_warning_once(
+                        "Found act_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. ")
 
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                          requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+                self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                             requires_grad=False)
+
+            assert self.w13_scale is not None
+            shard_size = self.intermediate_size
+            max_w13_scales = self.w13_scale.max(dim=1).values
+            for expert_id in range(self.num_total_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        self.w13_weight[expert_id][start:start +
+                                                   shard_size, :],
+                        self.w13_scale[expert_id][shard_id])
+                    self.w13_weight[expert_id][
+                        start:start + shard_size, :] = per_tensor_quantize(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+
+            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape