@@ -19,10 +19,33 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 
 def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    batch_dim_padding: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    """
+    Quantize the input tensor to FP8 and return the quantized tensor and scale.
+    This function supports both static and dynamic quantization: if a
+    scale is provided, static scaling is used; if it is omitted, the
+    scale is computed dynamically from the input. The function also allows
+    optional padding of the output tensor for downstream kernels that
+    benefit from padding.
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        batch_dim_padding: If specified, pad the first dimension
+            of the output to at least this value.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            the scaling factor.
+    """
+    if batch_dim_padding:
+        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
+        output = torch.empty(shape,
+                             device=input.device,
+                             dtype=torch.float8_e4m3fn)
+    else:
+        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         ops.dynamic_scaled_fp8_quant(output, input, scale)
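For reference, a minimal usage sketch of the new signature (not part of the patch; it assumes a vLLM build with the custom FP8 quant ops and a GPU that supports torch.float8_e4m3fn):

import torch

x = torch.randn(3, 4096, dtype=torch.float16, device="cuda")

# Dynamic quantization: no scale is passed, so one is computed from x.
qx, x_scale = scaled_fp8_quant(x)
assert qx.dtype == torch.float8_e4m3fn and qx.shape == x.shape

# Static quantization with padding: the first output dimension is padded
# to max(batch_dim_padding, x.shape[0]) = 17 rows here.
qx_padded, x_scale = scaled_fp8_quant(x, x_scale, batch_dim_padding=17)
assert qx_padded.shape == (17, 4096)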
@@ -245,9 +268,14 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
-
-        # Fused GEMM_DQ
+        qinput, x_scale = scaled_fp8_quant(x,
+                                           layer.act_scale,
+                                           batch_dim_padding=17)
+
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -257,7 +285,7 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
 
-        return output
+        return torch.narrow(output, 0, 0, x.shape[0])
 
 
 def all_close_1d(x: torch.Tensor) -> bool:
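To illustrate the pad-then-slice pattern above with plain tensors (a standalone sketch, not the FP8 path itself): the GEMM runs on an input padded to at least 17 rows, and torch.narrow trims the result back to the caller's true batch size.

import torch

batch, hidden, out_features = 3, 8, 4
x = torch.randn(batch, hidden)

# Stand-in for the padded quantized activations: at least 17 rows.
padded = torch.zeros(max(17, batch), hidden)
padded[:batch] = x

w = torch.randn(hidden, out_features)
output = padded @ w  # stand-in for the fused FP8 GEMM (torch._scaled_mm)

# Slice back to the true batch size, mirroring
# torch.narrow(output, 0, 0, x.shape[0]) in the diff.
output = torch.narrow(output, 0, 0, batch)
assert output.shape == (batch, out_features)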