@@ -216,7 +216,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # ACTIVATION SCALE
         if self.quant_config.activation_scheme == "static":
             self._create_scale_param(
-                scale_name="act_scale",
+                scale_name="input_scale",
                 layer=layer,
                 output_partition_sizes=output_partition_sizes,
                 **extra_weight_attrs)
@@ -248,7 +248,7 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
-            layer.act_scale = None
+            layer.input_scale = None
             return

         # If checkpoint is fp8, requantize the separately quantized logical
@@ -277,14 +277,14 @@ class Fp8LinearMethod(LinearMethodBase):
         # Dynamic: set to None (required input to ops.scaled_fp8_quant).
         # Static: set to max of the act_scales (since they are equal).
         if self.quant_config.activation_scheme == "dynamic":
-            layer.act_scale = None
+            layer.input_scale = None
         elif self.quant_config.activation_scheme == "static":
-            if not all_close_1d(layer.act_scale):
+            if not all_close_1d(layer.input_scale):
                 raise ValueError(
                     "All the act_scales for the logical weights of a layer "
-                    f"must be equal. But got {layer.act_scale}")
-            layer.act_scale = Parameter(layer.act_scale.max(),
-                                        requires_grad=False)
+                    f"must be equal. But got {layer.input_scale}")
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
         else:
             raise ValueError(
                 f"Unknown scheme {self.quant_config.activation_scheme}")
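
For reference, a minimal sketch of what the static branch above does after the rename, assuming `layer.input_scale` was loaded as a 1-D tensor holding one (identical) scale per logical weight; `torch.allclose` stands in for the `all_close_1d` helper, whose implementation is not shown in this diff:

```python
import torch
from torch.nn import Parameter

# Hypothetical per-logical-weight activation scales loaded from a checkpoint.
input_scale = torch.tensor([0.023, 0.023, 0.023])

# Mirror of the static branch: every scale must match before collapsing.
# torch.allclose is a stand-in for the all_close_1d helper used in the diff.
if not torch.allclose(input_scale, input_scale[0].expand_as(input_scale)):
    raise ValueError(
        "All the act_scales for the logical weights of a layer "
        f"must be equal. But got {input_scale}")

# Collapse to a single scalar Parameter, as the static branch does.
input_scale = Parameter(input_scale.max(), requires_grad=False)
```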
@@ -294,10 +294,10 @@ class Fp8LinearMethod(LinearMethodBase):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.act_scale is None and x_scale computed from x.
-        # If static, layer.act_scale is scalar and x_scale set to act_scale.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale set to input_scale.
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
+            qinput, x_scale = scaled_fp8_quant(x, layer.input_scale)

             # Fused GEMM_DQ
             output = cutlass_scaled_mm_dq(
@@ -310,7 +310,7 @@ class Fp8LinearMethod(LinearMethodBase):

         else:
             qinput, x_scale = scaled_fp8_quant(x,
-                                               layer.act_scale,
+                                               layer.input_scale,
                                                batch_dim_padding=17)

             # Fused GEMM_DQ -- note we padded the input above because
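
And a short sketch of how the renamed attribute is consumed at apply time, using only the call patterns visible in the hunks above; the import path is an assumption (the function lives in vLLM's custom-ops module) and an FP8-capable GPU is required for it to actually run:

```python
import torch

from vllm._custom_ops import scaled_fp8_quant  # assumed import path

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Dynamic scheme: layer.input_scale is None, so the scale is computed from x.
qinput, x_scale = scaled_fp8_quant(x, None)

# Static scheme: layer.input_scale is the scalar Parameter produced in
# process_weights_after_loading and is passed straight through as x_scale.
static_scale = torch.tensor(0.023, device="cuda")
qinput, x_scale = scaled_fp8_quant(x, static_scale)

# Non-cutlass path: pad the batch dimension to at least 17 rows before the
# fused GEMM_DQ, as in the last hunk.
qinput, x_scale = scaled_fp8_quant(x, static_scale, batch_dim_padding=17)
```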