@@ -107,31 +107,43 @@ def apply_fp8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
     #   If dynamic, layer.input_scale is None and x_scale computed from x.
     #   If static, layer.input_scale is scalar and x_scale is input_scale.

+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)

         # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
-
+        return ops.cutlass_scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)
+
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
     else:
-        # NOTE: we pad the input because torch._scaled_mm is more performant
+        # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
-        qinput, x_scale = ops.scaled_fp8_quant(input,
-                                               input_scale,
-                                               batch_dim_padding=17)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            batch_dim_padding=17,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
+
+        per_tensor_weights = (weight_scale.numel() == 1)
+        per_tensor_activations = (x_scale.numel() == 1)

-        if weight_scale.numel() == 1:
+        if per_tensor_weights and per_tensor_activations:
             # Fused GEMM_DQ
             output, _ = torch._scaled_mm(qinput,
                                          weight,
@@ -139,9 +151,11 @@ def apply_fp8_linear(
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          bias=bias)
+            return torch.narrow(output, 0, 0, input.shape[0])
+
         else:
-            # Fallback for channelwise case, where the weight scales are
-            # applied separately.
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm

             # Symmetric quantized GEMM by definition computes the following:
             #   C = (s_x * X) (s_w * W) + bias
@@ -155,21 +169,21 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.

-            # This computes C = sx * (X * W).
+            # GEMM
+            # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
             output, _ = torch._scaled_mm(qinput,
                                          weight,
-                                         out_dtype=torch.float32,
-                                         scale_a=x_scale)
+                                         out_dtype=torch.float32)
+            # Unpad (undo batch_dim_padding)
+            output = torch.narrow(output, 0, 0, input.shape[0])

-            # C = sw * sx * (X * W)
-            output = output * weight_scale.t()
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
             if bias is not None:
-                # C = sw * sx * (X * W) + bias
                 output = output + bias
-            output = output.to(dtype=input.dtype)
-
-            return torch.narrow(output, 0, 0, input.shape[0])
+            return output.to(dtype=input.dtype)


 def apply_int8_linear(
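
The unfused fallback leans on the identity spelled out in the comments above: with symmetric scales, C = (s_x * X) (s_w * W) + bias = s_x * s_w * (X * W) + bias, so the GEMM can run on the quantized operands and the scales can be applied afterwards as an elementwise broadcast. The sketch below is illustrative only: quantization is simulated in fp32 rather than cast to fp8, the [M, 1] per-token and [N, 1] per-channel scale shapes are assumptions about this path rather than the exact output of ops.scaled_fp8_quant, and 448.0 stands in for the fp8 e4m3 maximum.

import torch

# Hypothetical shapes for illustration.
M, K, N = 8, 64, 32
x = torch.randn(M, K)          # activations
w = torch.randn(N, K)          # weights, one row per output channel
bias = torch.randn(N)

# Per-token activation scales [M, 1] and per-channel weight scales [N, 1].
x_scale = x.abs().amax(dim=1, keepdim=True) / 448.0
w_scale = w.abs().amax(dim=1, keepdim=True) / 448.0

# "Quantized" operands, kept in fp32 so only the scaling math is visible.
x_q = x / x_scale
w_q = w / w_scale

# Unfused GEMM + DQ, mirroring the fallback path:
#   C = s_x * s_w * (X_q @ W_q^T) + bias
out = x_q @ w_q.t()                  # GEMM on quantized operands
out = out * x_scale * w_scale.t()    # DQ: broadcast [M, 1] * [1, N]
out = out + bias

# Dequantize-first reference matches up to floating-point rounding.
ref = (x_q * x_scale) @ (w_q * w_scale).t() + bias
assert torch.allclose(out, ref, rtol=1e-4, atol=1e-3)

The same broadcast also covers the mixed case that can reach this branch (per-tensor activation scale with per-channel weight scale), since a scalar x_scale broadcasts over the rows in exactly the same way.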
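The pad-and-unpad pattern in the fallback is also worth a quick note: the quantized input is padded so the batch dimension is at least 17 (the comment above says torch._scaled_mm is more performant for batch dimensions > 16), and torch.narrow afterwards slices the output back down to the real number of rows. A rough sketch of the pattern, where pad_batch_dim is a hypothetical helper standing in for what batch_dim_padding=17 asks ops.scaled_fp8_quant to do:

import torch
import torch.nn.functional as F

def pad_batch_dim(x: torch.Tensor, min_rows: int = 17) -> torch.Tensor:
    # Hypothetical helper: zero-pad rows until the batch dim reaches min_rows.
    if x.shape[0] >= min_rows:
        return x
    return F.pad(x, (0, 0, 0, min_rows - x.shape[0]))

x = torch.randn(3, 64)                      # real batch of 3 rows
padded = pad_batch_dim(x)                   # shape [17, 64]
out = padded @ torch.randn(64, 32)          # the GEMM sees the padded batch
out = torch.narrow(out, 0, 0, x.shape[0])   # unpad: back to [3, 32]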