
improve fp8 linear layer performance

AlpinDale 8 months ago
parent commit c4c153863e
1 changed file with 35 additions and 7 deletions

+ 35 - 7
aphrodite/quantization/fp8.py

@@ -19,10 +19,33 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 
 def scaled_fp8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    batch_dim_padding: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+    This function supports both static and dynamic quantization: if you
+    provide the scale, static scaling is used; if you omit it, the scale
+    is determined dynamically. The function also allows
+    optional padding of the output tensor for downstream kernels that
+    will benefit from padding.
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        batch_dim_padding: If specified, pad the first dimension
+            of the output to at least this value.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    if batch_dim_padding:
+        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
+        output = torch.empty(shape,
+                             device=input.device,
+                             dtype=torch.float8_e4m3fn)
+    else:
+        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
         scale = torch.zeros(1, device=input.device, dtype=torch.float32)
         ops.dynamic_scaled_fp8_quant(output, input, scale)
@@ -245,9 +268,14 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         #   If dynamic, layer.act_scale is None and x_scale computed from x.
         #   If static,  layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
-
-        # Fused GEMM_DQ
+        qinput, x_scale = scaled_fp8_quant(x,
+                                           layer.act_scale,
+                                           batch_dim_padding=17)
+
+        # Fused GEMM_DQ -- we padded the input above because
+        # torch._scaled_mm is more performant for matrices with a
+        # batch dimension > 16, though this could change
+        # in the future.
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -257,7 +285,7 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
 
-        return output
+        return torch.narrow(output, 0, 0, x.shape[0])
 
 
 def all_close_1d(x: torch.Tensor) -> bool:
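
Not part of the commit: a hedged usage sketch of the new batch_dim_padding argument added above. It assumes aphrodite is installed with its compiled quantization ops and that a CUDA device with FP8 support is available; the tensor sizes are illustrative only.

import torch
from aphrodite.quantization.fp8 import scaled_fp8_quant

# A small activation batch of 4 rows.
x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)

# Dynamic quantization: no scale is passed, so one is computed from x.
# The output is padded along the batch dimension to at least 17 rows;
# rows beyond the original 4 are uninitialized and meant to be discarded
# downstream.
qx, x_scale = scaled_fp8_quant(x, batch_dim_padding=17)

print(qx.dtype)       # torch.float8_e4m3fn
print(qx.shape)       # torch.Size([17, 4096])
print(x_scale.shape)  # torch.Size([1])

Passing a precomputed scale tensor instead of omitting it switches the call to static scaling, as the new docstring describes.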
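The pad-then-narrow bookkeeping used in the forward pass above can also be seen in isolation. The following is a minimal sketch (an illustration, not code from the repository) that stands in an ordinary matmul for the fused FP8 GEMM so it runs on any device.

import torch

def padded_matmul(x: torch.Tensor, weight: torch.Tensor,
                  batch_dim_padding: int = 17) -> torch.Tensor:
    # Pad the batch (first) dimension to at least `batch_dim_padding` rows,
    # mirroring what scaled_fp8_quant now does to its output tensor.
    rows = max(batch_dim_padding, x.shape[0])
    padded = torch.zeros((rows, *x.shape[1:]), device=x.device, dtype=x.dtype)
    padded[: x.shape[0]] = x

    # Stand-in for the fused GEMM; the real code calls torch._scaled_mm on
    # the FP8-quantized, padded input because it is faster for batch > 16.
    out = padded @ weight

    # Slice off the padded rows so callers see the original batch size,
    # exactly as the diff does with torch.narrow(output, 0, 0, x.shape[0]).
    return torch.narrow(out, 0, 0, x.shape[0])

# A 4-row input is padded to 17 rows internally, but the result has 4 rows.
x = torch.randn(4, 8)
w = torch.randn(8, 16)
assert padded_matmul(x, w).shape == (4, 16)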