MPLinearKernel.py

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch

from aphrodite.quantization.utils import replace_parameter
from aphrodite.scalar_type import ScalarType


@dataclass
class MPLinearLayerConfig:
    full_weight_shape: Tuple[int, int]  # [in, out]
    partition_weight_shape: Tuple[int, int]  # this rank's shard, [in, out]
    weight_type: ScalarType  # quantized weight type
    act_type: torch.dtype  # activation dtype
    group_size: int  # quantization group size
    zero_points: bool  # True if the scheme uses zero points (asymmetric)
    has_g_idx: bool  # True if activation reordering (g_idx) is used
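

# Illustrative only (not part of the original file): what a populated config
# might look like for a uint4, group-128 layer whose output dimension is
# column-sharded across two tensor-parallel ranks. `scalar_types.uint4b8`
# assumes aphrodite.scalar_type exposes a vLLM-style `scalar_types`
# registry; adjust to the actual API.
#
#   from aphrodite.scalar_type import scalar_types
#
#   cfg = MPLinearLayerConfig(
#       full_weight_shape=(4096, 4096),       # [in, out]
#       partition_weight_shape=(4096, 2048),  # out dim split over 2 ranks
#       weight_type=scalar_types.uint4b8,
#       act_type=torch.float16,
#       group_size=128,
#       zero_points=False,
#       has_g_idx=False,
#   )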


class MPLinearKernel(ABC):

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU compute capability required, as an integer
        (e.g. 80 for Ampere)."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Return (True, None) if this kernel can handle `c`, otherwise
        (False, reason-it-cannot)."""
        raise NotImplementedError

    def __init__(self,
                 c: MPLinearLayerConfig,
                 w_q_param_name: str,
                 w_s_param_name: str,
                 w_zp_param_name: Optional[str] = None,
                 w_gidx_param_name: Optional[str] = None) -> None:
        # can_implement returns a (bool, reason) tuple, which is always
        # truthy as-is; unpack it so the assert actually checks the flag.
        can_impl, reason = self.can_implement(c)
        assert can_impl, reason
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack/transform the layer's parameters into the format this
        kernel expects; called once after the checkpoint is loaded."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the processed quantized weights to `x`, adding `bias`
        if given."""
        raise NotImplementedError

    def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
                         fn: Callable) -> None:
        if name is not None and getattr(layer, name, None) is not None:
            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # Replace the parameter with a torch.nn.Parameter for
            # TorchDynamo compatibility.
            replace_parameter(
                layer, name,
                torch.nn.Parameter(new_param.data, requires_grad=False))

    def _get_weight_params(
            self, layer: torch.nn.Module
    ) -> Tuple[torch.Tensor,  # w_q
               torch.Tensor,  # w_s
               Optional[torch.Tensor],  # w_zp
               Optional[torch.Tensor]  # w_gidx
               ]:
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
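

# ---------------------------------------------------------------------------
# Hypothetical sketch (not part of the original file): a minimal concrete
# kernel showing how the ABC above is meant to be subclassed. It dequantizes
# once at load time via _transform_param and falls back to a plain matmul;
# real kernels instead repack the weights for a fused quantized GEMM. The
# class and the _dequantize_sketch helper are invented for illustration, and
# the dequantization assumes unpacked int weights with row-grouped scales.
# ---------------------------------------------------------------------------


class _NaiveDequantKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # illustrative threshold, not a real requirement

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx:
            return False, "g_idx (activation reordering) not supported"
        if c.zero_points:
            return False, "zero points not supported by this sketch"
        if c.group_size <= 0:
            return False, "only grouped quantization is supported"
        return True, None

    @staticmethod
    def _dequantize_sketch(w_q: torch.Tensor, w_s: torch.Tensor,
                           group_size: int,
                           act_type: torch.dtype) -> torch.Tensor:
        # Assumes w_q is an *unpacked* integer tensor of shape [in, out] and
        # w_s holds per-group scales of shape [in // group_size, out]; real
        # checkpoints pack several sub-byte values per int32, which this
        # sketch deliberately ignores.
        scales = w_s.repeat_interleave(group_size, dim=0).to(act_type)
        return w_q.to(act_type) * scales

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w_q, w_s, _, _ = self._get_weight_params(layer)
        # The intended _transform_param pattern: rewrite the named parameter
        # in place, here replacing the quantized weight with a dense one.
        self._transform_param(
            layer, self.w_q_name, lambda w: self._dequantize_sketch(
                w, w_s, self.config.group_size, self.config.act_type))

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        w, _, _, _ = self._get_weight_params(layer)  # now dense [in, out]
        out = torch.matmul(x, w)
        return out if bias is None else out + bias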