compressed_tensors_scheme.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from abc import ABC, abstractmethod
  2. from typing import Optional
  3. import torch
  4. __all__ = ["CompressedTensorsScheme"]
  5. class CompressedTensorsScheme(ABC):
  6. """
  7. Abstract class used to describe the weight creation and forward pass
  8. of different quantization schemes supported by CompressedTensors.
  9. """
  10. @classmethod
  11. @abstractmethod
  12. def get_min_capability(cls) -> int:
  13. """
  14. Get minimum device capability.
  15. """
  16. raise NotImplementedError
  17. @abstractmethod
  18. def create_weights(self, *args, **kwargs):
  19. """
  20. Weight creation for the particular scheme. Inputs to this function
  21. """
  22. raise NotImplementedError
  23. @abstractmethod
  24. def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
  25. bias: Optional[torch.Tensor]):
  26. """
  27. Run the forward pass for the particular scheme. This is where
  28. scheme-specific dequant/quant steps/kernels should be applied.
  29. :param layer: torch.nn.Module with the registered weights and
  30. other parameters relevant to the particular scheme.
  31. :param x: input to the layer
  32. :param bias: bias parameter for the layer
  33. """
  34. raise NotImplementedError
  35. @abstractmethod
  36. def process_weights_after_loading(self, layer: torch.nn.Module):
  37. """
  38. Called after weight loading is complete for any cleanup that
  39. needs to occur.
  40. """
  41. raise NotImplementedError