
fix: support bias term in compressed-tensors quant

AlpinDale 6 months ago
parent
commit
500f3b654f

+ 1 - 4
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -266,10 +266,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
 
         """
 
-        if bias is not None:
-            raise ValueError("bias is not supported for this linear method")
-
         scheme = layer.scheme
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
-        return scheme.apply_weights(layer, x)
+        return scheme.apply_weights(layer, x, bias=bias)

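With the guard removed, CompressedTensorsLinearMethod.apply now hands whatever bias the layer carries straight to the scheme. A small sanity sketch of the contract every scheme has to meet after this change (not part of the patch; scheme, layer, x, and bias are assumed to come from an already-loaded model): the fused-bias result should match running the scheme without a bias and adding it afterwards.

import torch

def bias_contract_holds(scheme, layer: torch.nn.Module, x: torch.Tensor,
                        bias: torch.Tensor) -> bool:
    # Loose tolerances: the quantized kernels are not bit-exact.
    with torch.no_grad():
        fused = scheme.apply_weights(layer, x, bias=bias)
        split = scheme.apply_weights(layer, x, bias=None) + bias
    return torch.allclose(fused, split, rtol=1e-2, atol=1e-2)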
+ 5 - 2
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Optional
 
 import torch
 
@@ -20,14 +21,16 @@ class CompressedTensorsScheme(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]):
         """
         Run the forward pass for the particular scheme. This is where 
         scheme-specific dequant/quant steps/kernels should be applied.
 
-        :param layer: toch.nn.Module with the registered weights and 
+        :param layer: torch.nn.Module with the registered weights and 
             other parameters relevant to the particular scheme. 
         :param x: input to the layer
+        :param bias: bias parameter for the layer
 
         """
         raise NotImplementedError

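For reference, a hypothetical minimal scheme written against the updated abstract signature (ExampleScheme is an illustration only, not one of the schemes touched by this commit; its other required methods are elided):

from typing import Optional

import torch
import torch.nn.functional as F

from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_scheme import (
    CompressedTensorsScheme)

class ExampleScheme(CompressedTensorsScheme):
    # Weight-creation and loading hooks elided for brevity.

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Every scheme must now accept the optional bias and either fuse
        # it into its kernel or apply it to the output itself.
        return F.linear(x, layer.weight, bias)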
+ 5 - 4
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py

@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -37,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, {"weight_loader": weight_loader})
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        return F.linear(x, weight)
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
+        return F.linear(x, layer.weight, bias)

+ 6 - 1
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

@@ -118,7 +118,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                               requires_grad=False)
         layer.workspace = workspace
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
         qweight = layer.weight_packed
         meta = layer.meta
         scales = layer.scale_packed
@@ -135,4 +136,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                                             size_n, size_k)
 
         output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
         return output

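The 2:4 sparse Marlin kernel used by this scheme is not handed the bias, so it is added to the reshaped output instead; the in-place add_ broadcasts a [output_size] bias across all leading dimensions. A toy shape check with made-up sizes:

import torch

batch, seq, out_features = 2, 5, 8            # made-up sizes
output = torch.zeros(batch, seq, out_features)
bias = torch.arange(out_features, dtype=torch.float32)

output.add_(bias)   # broadcasts over the (batch, seq) dimensions
assert torch.equal(output[1, 3], bias)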
+ 6 - 3
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional
 
 import torch
 from torch.nn import Parameter
@@ -78,8 +78,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                                                   **layer_kwargs)
             layer.register_parameter("input_scale", scale)
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale)
+                                 input_scale=layer.input_scale,
+                                 bias=bias)

+ 5 - 2
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

@@ -148,7 +148,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         return apply_marlin_linear(
             input=x,
             weight=layer.weight_packed,
@@ -159,4 +161,5 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             num_bits=self.num_bits,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            is_k_full=True)
+            is_k_full=True,
+            bias=bias)

+ 2 - 3
aphrodite/quantization/utils/w8a8_utils.py

@@ -148,8 +148,6 @@ def apply_int8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
 ):
-    if bias is not None:
-        raise NotImplementedError("W8A8 with int8 does not yet support bias.")
 
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
@@ -160,4 +158,5 @@ def apply_int8_linear(
                                  weight,
                                  scale_a=x_scale,
                                  scale_b=weight_scale,
-                                 out_dtype=input.dtype)
+                                 out_dtype=input.dtype,
+                                 bias=bias)
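With the NotImplementedError gone, the bias rides into the fused GEMM call. As a plain-PyTorch reference for what the W8A8 int8 path computes once a bias is supplied, assuming symmetric static per-tensor activation quantization and an [out_features, in_features] int8 weight (an illustrative sketch, not the CUTLASS kernel itself):

from typing import Optional

import torch

def int8_linear_reference(x: torch.Tensor, weight_q: torch.Tensor,
                          weight_scale: torch.Tensor,
                          input_scale: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Quantize activations with the static input scale.
    x_q = torch.clamp(torch.round(x.to(torch.float32) / input_scale), -128, 127)
    # int8 x int8 matmul (emulated in fp32), dequantized with both scales.
    out = (x_q @ weight_q.to(torch.float32).t()) * input_scale * weight_scale
    if bias is not None:
        out = out + bias        # the bias epilogue enabled by this change
    return out.to(x.dtype)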