
chore: enable dynamic per-token `fp8`

AlpinDale, 7 months ago
commit d3c474d219

+ 13 - 13
aphrodite/_custom_ops.py

@@ -297,6 +297,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -312,6 +313,8 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         batch_dim_padding: If specified, pad the first dimension
             of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_token (instead of
+            per_tensor) quantization in the dynamic case.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -325,22 +328,19 @@ def scaled_fp8_quant(
     else:
         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
-        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        if use_per_token_if_dynamic:
+            scale = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
-    return output, scale
-
 
-def dynamic_per_token_scaled_fp8_quant(
-        input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    scales = torch.empty((input.numel() // input.shape[-1], 1),
-                         device=input.device,
-                         dtype=torch.float32)
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
-    return output, scales
+    return output, scale
 
 
 # int8
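
For reference, here is a minimal pure-PyTorch sketch of what the new dynamic per-token path computes. This is an illustration of the semantics only, not the `torch.ops._C` kernel; the exact clamping and epsilon handling are assumptions. The key point is one float32 scale per token (row of the flattened input), matching the `(input.numel() // input.shape[-1], 1)` scale tensor allocated above.

```python
import torch

def per_token_fp8_quant_sketch(x: torch.Tensor):
    # Illustrative reference for dynamic per-token FP8 quantization:
    # one scale per token, computed from the row-wise absolute maximum.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_2d = x.reshape(-1, x.shape[-1]).float()
    row_max = x_2d.abs().amax(dim=-1, keepdim=True)      # (num_tokens, 1)
    scale = (row_max / fp8_max).clamp(min=1e-12)         # avoid div-by-zero
    q = (x_2d / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale
```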

+ 2 - 1
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -103,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=True)

+ 2 - 1
aphrodite/quantization/fp8.py

@@ -213,7 +213,8 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False)
 
 
 class Fp8MoEMethod(FusedMoEMethodBase):
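
Taken together, the two call-site changes mean that only the compressed-tensors W8A8 scheme opts into per-token scales for dynamic activation quantization, while the plain `Fp8LinearMethod` keeps per-tensor behaviour. A small usage sketch of the difference, assuming a CUDA build where the custom quant ops are available (the import path is inferred from the file layout above):

```python
import torch
from aphrodite import _custom_ops as ops

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor (fp8.py path): a single scalar scale is computed.
q, scale = ops.scaled_fp8_quant(x, None, use_per_token_if_dynamic=False)
assert scale.numel() == 1

# Dynamic per-token (compressed-tensors path): one scale per row/token.
q, scale = ops.scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
assert scale.shape == (8, 1)
```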

+ 38 - 24
aphrodite/quantization/utils/w8a8_utils.py

@@ -107,31 +107,43 @@ def apply_fp8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
     #   If dynamic, layer.input_scale is None and x_scale computed from x.
     #   If static, layer.input_scale is scalar and x_scale is input_scale.
 
+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
-
+        return ops.cutlass_scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)
+
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
     else:
-        # NOTE: we pad the input because torch._scaled_mm is more performant
+        # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
-        qinput, x_scale = ops.scaled_fp8_quant(input,
-                                               input_scale,
-                                               batch_dim_padding=17)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            batch_dim_padding=17,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
+
+        per_tensor_weights = (weight_scale.numel() == 1)
+        per_tensor_activations = (x_scale.numel() == 1)
 
-        if weight_scale.numel() == 1:
+        if per_tensor_weights and per_tensor_activations:
             # Fused GEMM_DQ
             output, _ = torch._scaled_mm(qinput,
                                          weight,
@@ -139,9 +151,11 @@ def apply_fp8_linear(
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          bias=bias)
+            return torch.narrow(output, 0, 0, input.shape[0])
+
         else:
-            # Fallback for channelwise case, where the weight scales are
-            # applied separately.
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
 
             # Symmetric quantized GEMM by definition computes the following:
             #   C = (s_x * X) (s_w * W) + bias
@@ -155,21 +169,21 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
 
-            # This computes C = sx * (X * W).
+            # GEMM
+            # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
             output, _ = torch._scaled_mm(qinput,
                                          weight,
-                                         out_dtype=torch.float32,
-                                         scale_a=x_scale)
+                                         out_dtype=torch.float32)
+            # Unpad (undo batch_dim_padding)
+            output = torch.narrow(output, 0, 0, input.shape[0])
 
-            # C = sw * sx * (X * W)
-            output = output * weight_scale.t()
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
             if bias is not None:
-                # C = sw * sx * (X * W) + bias
                 output = output + bias
-            output = output.to(dtype=input.dtype)
-
-    return torch.narrow(output, 0, 0, input.shape[0])
+            return output.to(dtype=input.dtype)
 
 
 def apply_int8_linear(
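
To make the broadcast in the unfused dequant concrete, here is a shape-level sketch of the fallback path. A plain float matmul stands in for `torch._scaled_mm`, and the tensor shapes are illustrative assumptions consistent with the per-token/per-channel layout above: per-token activation scales have shape (M, 1), transposed per-channel weight scales have shape (1, N), so `output * x_scale * weight_scale.t()` applies both scales in a single broadcast over the (M, N) GEMM output.

```python
import torch

M, K, N = 4, 64, 32
qinput = torch.randn(M, K).to(torch.float8_e4m3fn)   # quantized activations
weight = torch.randn(K, N).to(torch.float8_e4m3fn)   # quantized weight
x_scale = torch.rand(M, 1)                           # per-token scales
weight_scale = torch.rand(N, 1)                      # per-channel scales

# GEMM in fp32 without scales, standing in for
# torch._scaled_mm(qinput, weight, out_dtype=torch.float32)
output = qinput.float() @ weight.float()             # (M, N)

# DQ: C = s_x * s_w * (X * W); (M, 1) and (1, N) broadcast over (M, N)
output = output * x_scale * weight_scale.t()
```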