|
@@ -1,4 +1,5 @@
|
|
|
from abc import ABC, abstractmethod
|
|
|
+from typing import Optional
|
|
|
|
|
|
import torch
|
|
|
|
|
@@ -20,14 +21,16 @@ class CompressedTensorsScheme(ABC):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
@abstractmethod
|
|
|
- def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
|
|
+ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
|
|
+ bias: Optional[torch.Tensor]):
|
|
|
"""
|
|
|
Run the forward pass for the particular scheme. This is where
|
|
|
scheme-specific dequant/quant steps/kernels should be applied.
|
|
|
|
|
|
- :param layer: toch.nn.Module with the registered weights and
|
|
|
+ :param layer: torch.nn.Module with the registered weights and
|
|
|
other parameters relevant to the particular scheme.
|
|
|
:param x: input to the layer
|
|
|
+ :param bias: bias parameter for the layer
|
|
|
|
|
|
"""
|
|
|
raise NotImplementedError
|