
fp8: `act_scale` -> `input_scale`

AlpinDale 7 months ago
parent
commit
ab5ffb228c

+ 4 - 4
aphrodite/modeling/models/mixtral.py

@@ -179,7 +179,7 @@ class MixtralMoE(nn.Module):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
         # Loading scales
-        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
@@ -223,8 +223,8 @@ class MixtralMoE(nn.Module):
 
         else:
             # If checkpoint is fp8 + static, cleanup act_scales.
-            #   Since state_dict has an act_scale per expert but our kernels
-            #   are passed one act_scale shared across all experts.
+            #   Since state_dict has an input_scale per expert but our kernels
+            #   are passed one input_scale shared across all experts.
             if self.quant_config.activation_scheme == "static":
                 if self.a13_scale is None or self.a2_scale is None:
                     raise ValueError(
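
The comment in this hunk points out that the checkpoint carries one `input_scale` per expert while the fused MoE kernel is handed a single shared scale, so the loader only accepts per-expert values that agree. Below is a minimal, standalone sketch of that check (the helper name, the 8-expert tensor, and the assumption that the scale parameter starts out as 1.0 are illustrative, not taken from the diff):

```python
import torch

def check_and_store_expert_scale(param_data: torch.Tensor, expert_id: int,
                                 loaded_weight: torch.Tensor) -> None:
    # If this expert's slot was already written (anything other than the
    # initial 1.0), the newly loaded scale must agree within 1e-5, because
    # the kernel later uses one input_scale shared across all experts.
    if param_data[expert_id] != 1 and (param_data[expert_id] -
                                       loaded_weight).abs() > 1e-5:
        raise ValueError("per-expert input_scale values do not match")
    param_data[expert_id] = loaded_weight

scales = torch.ones(8)  # hypothetical 8-expert layer, default scale 1.0
check_and_store_expert_scale(scales, 0, torch.tensor(0.02))
```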
@@ -571,7 +571,7 @@ class MixtralForCausalLM(nn.Module):
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
             ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ]
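
For reference, the renamed tuple list in the last hunk maps each expert's `input_scale` checkpoint key onto the module-level scale it feeds (`a13_scale` for the gate/up projections `w1`/`w3`, `a2_scale` for the down projection `w2`). A standalone sketch of what the comprehension expands to, using a hypothetical two-expert count in place of `self.config.num_local_experts`:

```python
num_local_experts = 2  # hypothetical; the model reads this from its config

expert_scale_mapping = [
    ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
     f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
    for expert_id in range(num_local_experts)
    for weight_name in ["w1", "w2", "w3"]
]
# First few entries:
#   ("a13_scale", "experts.0.w1.input_scale", 0)
#   ("a2_scale",  "experts.0.w2.input_scale", 0)
#   ("a13_scale", "experts.0.w3.input_scale", 0)
```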

+ 11 - 11
aphrodite/quantization/fp8.py

@@ -216,7 +216,7 @@ class Fp8LinearMethod(LinearMethodBase):
             # ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
                 self._create_scale_param(
-                    scale_name="act_scale",
+                    scale_name="input_scale",
                     layer=layer,
                     output_partition_sizes=output_partition_sizes,
                     **extra_weight_attrs)
@@ -248,7 +248,7 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
-            layer.act_scale = None
+            layer.input_scale = None
             return
 
         # If checkpoint is fp8, requantize the separately quantized logical
@@ -277,14 +277,14 @@ class Fp8LinearMethod(LinearMethodBase):
             #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
             #   Static:  set to max of the act_scales (since they are equal).
             if self.quant_config.activation_scheme == "dynamic":
-                layer.act_scale = None
+                layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.act_scale):
+                if not all_close_1d(layer.input_scale):
                     raise ValueError(
                         "All the act_scales for the logical weights of a layer "
-                        f"must be equal. But got {layer.act_scale}")
-                layer.act_scale = Parameter(layer.act_scale.max(),
-                                            requires_grad=False)
+                        f"must be equal. But got {layer.input_scale}")
+                layer.input_scale = Parameter(layer.input_scale.max(),
+                                              requires_grad=False)
             else:
                 raise ValueError(
                     f"Unknown scheme {self.quant_config.activation_scheme}")
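
Under the static scheme the hunk above keeps only one scalar `input_scale` per layer, after confirming that the per-logical-weight scales agree. A self-contained sketch of that collapse, with a hypothetical `collapse_input_scale` helper standing in for the checks shown above:

```python
import torch
from torch.nn import Parameter

def collapse_input_scale(scales: torch.Tensor, atol: float = 1e-5) -> Parameter:
    # Every logical weight fused into the layer must report the same
    # activation scale; otherwise a single shared scale would be wrong.
    if not torch.allclose(scales, scales.max().expand_as(scales), atol=atol):
        raise ValueError(f"input_scale values must be equal, got {scales}")
    # Keep a single scalar; max() changes nothing when the values agree.
    return Parameter(scales.max(), requires_grad=False)

shared_scale = collapse_input_scale(torch.tensor([0.02, 0.02, 0.02]))
```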
@@ -294,10 +294,10 @@ class Fp8LinearMethod(LinearMethodBase):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         # ops.scaled_fp8_quant supports both dynamic and static quant.
-        #   If dynamic, layer.act_scale is None and x_scale computed from x.
-        #   If static,  layer.act_scale is scalar and x_scale set to act_scale.
+        #   If dynamic, layer.input_scale is None and x_scale computed from x.
+        #   If static,  layer.input_scale is scalar and x_scale set to input_scale.
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
+            qinput, x_scale = scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
             output = cutlass_scaled_mm_dq(
@@ -310,7 +310,7 @@ class Fp8LinearMethod(LinearMethodBase):
 
         else:
             qinput, x_scale = scaled_fp8_quant(x,
-                                               layer.act_scale,
+                                               layer.input_scale,
                                                batch_dim_padding=17)
 
             # Fused GEMM_DQ -- note we padded the input above because
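
The updated comments in `apply` describe how the activation scale is chosen: with a dynamic checkpoint `layer.input_scale` is `None` and the kernel derives `x_scale` from `x`; with a static checkpoint the stored scalar is reused. A plain-PyTorch sketch of that behaviour, standing in for the fused `ops.scaled_fp8_quant` kernel (illustrative only; assumes a torch build with `torch.float8_e4m3fn` and ignores the `batch_dim_padding` detail):

```python
from typing import Optional, Tuple

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for E4M3

def toy_scaled_fp8_quant(
        x: torch.Tensor,
        input_scale: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    if input_scale is None:
        # Dynamic: derive the scale from the current activation tensor.
        x_scale = x.abs().max().float() / FP8_MAX
    else:
        # Static: reuse the single scalar scale loaded from the checkpoint.
        x_scale = input_scale
    qinput = (x / x_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return qinput, x_scale

q_dyn, s_dyn = toy_scaled_fp8_quant(torch.randn(4, 16), None)                 # dynamic
q_sta, s_sta = toy_scaled_fp8_quant(torch.randn(4, 16), torch.tensor(0.05))   # static
```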