# compressed_tensors_scheme.py (1.5 KB)
  1. from abc import ABC, abstractmethod
  2. from typing import Optional
  3. import torch
  4. __all__ = ["CompressedTensorsScheme"]
  5. class CompressedTensorsScheme(ABC):
  6. """
  7. Abstract class used to describe the weight creation and forward pass
  8. of different quantization schemes supported by CompressedTensors.
  9. """
  10. @abstractmethod
  11. def get_min_capability(self) -> int:
  12. """
  13. Get minimum device capability.
  14. """
  15. raise NotImplementedError
  16. @abstractmethod
  17. def create_weights(self, *args, **kwargs):
  18. """
  19. Weight creation for the particular scheme. Inputs to this function
  20. """
  21. raise NotImplementedError
  22. @abstractmethod
  23. def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
  24. bias: Optional[torch.Tensor]):
  25. """
  26. Run the forward pass for the particular scheme. This is where
  27. scheme-specific dequant/quant steps/kernels should be applied.
  28. :param layer: torch.nn.Module with the registered weights and
  29. other parameters relevant to the particular scheme.
  30. :param x: input to the layer
  31. :param bias: bias parameter for the layer
  32. """
  33. raise NotImplementedError
  34. @abstractmethod
  35. def process_weights_after_loading(self, layer: torch.nn.Module):
  36. """
  37. Called after weight loading is complete for any cleanup that
  38. needs to occur.
  39. """
  40. raise NotImplementedError