
fix: mixtral fp8 ckpt loading

AlpinDale committed 7 months ago
commit 39b36efabf
2 files changed, 79 additions and 35 deletions
  1. aphrodite/modeling/models/mixtral.py  +75 -32
  2. aphrodite/quantization/fp8.py  +4 -3

+ 75 - 32
aphrodite/modeling/models/mixtral.py

@@ -48,7 +48,8 @@ from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.quantization.fp8 import Fp8Config, scaled_fp8_quant
+from aphrodite.quantization.fp8 import (Fp8Config, per_tensor_dequantize,
+                                        per_tensor_quantize, scaled_fp8_quant)
 
 
 class MixtralMoE(nn.Module):
@@ -96,16 +97,16 @@ class MixtralMoE(nn.Module):
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
-        self.w13_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        dtype=params_dtype))
-        self.w2_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        dtype=params_dtype))
+        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
+                                                   2 * self.intermediate_size,
+                                                   self.hidden_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
+                                                  self.hidden_size,
+                                                  self.intermediate_size,
+                                                  dtype=params_dtype),
+                                      requires_grad=False)
 
         set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
             "weight_loader": self.weight_loader,
@@ -122,7 +123,10 @@ class MixtralMoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
             self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     2,
                                                      dtype=torch.float32),
                                           requires_grad=False)
             self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
@@ -146,11 +150,11 @@ class MixtralMoE(nn.Module):
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
                         "Found static activation scheme for checkpoint that "
                         "was not serialized fp8.")
                         "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.zeros(
+                self.a13_scale = nn.Parameter(torch.ones(
                     self.num_total_experts, dtype=torch.float32),
                                               requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
+                self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                        dtype=torch.float32),
                                              requires_grad=False)
 
                 set_weight_attrs(self.a13_scale, {
@@ -173,8 +177,22 @@ class MixtralMoE(nn.Module):
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
+
+        # Loading scales
+        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "act_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
+        elif "weight_scale" in weight_name:
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            assert "w1" in weight_name or "w3" in weight_name
+            shard_id = 0 if "w1" in weight_name else 1
+            param_data[expert_id][shard_id] = loaded_weight
 
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
@@ -187,6 +205,12 @@ class MixtralMoE(nn.Module):
                                           dtype=torch.float8_e4m3fn)
             w2_weight = torch.empty_like(self.w2_weight.data,
                                          dtype=torch.float8_e4m3fn)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
             for expert in range(self.num_total_experts):
                 w13_weight[
                     expert, :, :], self.w13_scale[expert] = scaled_fp8_quant(
@@ -197,25 +221,44 @@ class MixtralMoE(nn.Module):
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        #   Since state_dict has an act_scale per expert but our kernels
-        #   are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None.")
+        else:
+            # If checkpoint is fp8 + static, cleanup act_scales.
+            #   Since state_dict has an act_scale per expert but our kernels
+            #   are passed one act_scale shared across all experts.
+            if self.quant_config.activation_scheme == "static":
+                if self.a13_scale is None or self.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None.")
 
-            if (not all_close_1d(self.a13_scale)
-                    or not all_close_1d(self.a2_scale)):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. ")
+                if (not all_close_1d(self.a13_scale)
+                        or not all_close_1d(self.a2_scale)):
+                    print_warning_once(
+                        "Found act_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. ")
 
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                          requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+                self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                             requires_grad=False)
+
+            assert self.w13_scale is not None
+            shard_size = self.intermediate_size
+            max_w13_scales = self.w13_scale.max(dim=1).values
+            for expert_id in range(self.num_total_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        self.w13_weight[expert_id][start:start +
+                                                   shard_size, :],
+                        self.w13_scale[expert_id][shard_id])
+                    self.w13_weight[expert_id][
+                        start:start + shard_size, :] = per_tensor_quantize(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+
+            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape

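Note: the following is a minimal, self-contained sketch (not part of the commit; the function name and tensor shapes are illustrative) of the re-quantization step that process_weights_after_loading performs for fp8-serialized checkpoints. Each w1/w3 shard is dequantized with its own checkpoint scale (as per_tensor_dequantize does) and then re-quantized with the per-expert maximum scale (as per_tensor_quantize does), so the fused w13 weight can be served with a single scaling factor.

    import torch

    def requantize_w13_to_shared_scale(w13_weight: torch.Tensor,
                                       w13_scale: torch.Tensor,
                                       intermediate_size: int) -> torch.Tensor:
        # w13_weight: [2 * intermediate_size, hidden_size], torch.float8_e4m3fn
        # w13_scale:  [2], the per-shard scales loaded for w1 and w3
        finfo = torch.finfo(torch.float8_e4m3fn)
        max_scale = w13_scale.max()
        start = 0
        for shard_id in range(2):
            shard = w13_weight[start:start + intermediate_size, :]
            # Dequantize: upcast and multiply by the shard's old scale.
            dq = shard.to(torch.float16) * w13_scale[shard_id]
            # Re-quantize: divide by the shared scale, clamp to fp8 range, downcast.
            requant = (dq / max_scale).clamp(min=finfo.min, max=finfo.max)
            w13_weight[start:start + intermediate_size, :] = requant.to(
                torch.float8_e4m3fn)
            start += intermediate_size
        # The caller would store max_scale as the new (single) per-expert w13_scale.
        return max_scale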
+ 4 - 3
aphrodite/quantization/fp8.py

@@ -338,14 +338,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
 
 
 def per_tensor_quantize(tensor: torch.Tensor,
-                        inv_scale: float) -> torch.Tensor:
+                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
     return qweight.to(torch.float8_e4m3fn)
 
 
-def per_tensor_dequantize(tensor: torch.Tensor,
-                          inv_scale: float) -> torch.Tensor:
+def per_tensor_dequantize(
+        tensor: torch.Tensor, inv_scale: Union[float,
+                                               torch.Tensor]) -> torch.Tensor:
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
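A small usage sketch for the widened helper signatures (assuming per_tensor_quantize and per_tensor_dequantize are importable from aphrodite.quantization.fp8 and that the installed PyTorch exposes torch.float8_e4m3fn): the scale may now be a 0-d tensor, for example one element of the per-expert w13_scale parameter, instead of only a Python float.

    import torch

    from aphrodite.quantization.fp8 import (per_tensor_dequantize,
                                            per_tensor_quantize)

    weight = torch.randn(16, 32, dtype=torch.float16)
    # Per-tensor scale as a 0-d tensor: abs-max mapped onto the fp8 range.
    scale = weight.abs().max() / torch.finfo(torch.float8_e4m3fn).max

    qweight = per_tensor_quantize(weight, scale)    # torch.float8_e4m3fn
    approx = per_tensor_dequantize(qweight, scale)  # dequantized approximation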