compressed_tensors_wNa16.py

from typing import Callable, List, Optional, Set

import torch
from loguru import logger

from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          ChannelQuantScaleParameter,
                                          GroupQuantScaleParameter,
                                          PackedAphroditeParameter,
                                          RowAphroditeParameter)
from aphrodite.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from aphrodite.quantization.compressed_tensors.utils import ActivationOrdering
from aphrodite.quantization.kernels import (MPLinearLayerConfig,
                                            choose_mp_linear_kernel)
from aphrodite.quantization.utils.marlin_utils import (
    marlin_repeat_scales_on_all_ranks)
from aphrodite.scalar_type import scalar_types

__all__ = ["CompressedTensorsWNA16"]
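
# Map from weight bit-width to the storage type used by the kernels.
# The "b8"/"b128" suffixes denote a stored bias: a raw 4-bit value v
# represents v - 8 (and a raw 8-bit value v - 128), which gives a
# symmetric signed range without explicit zero points.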
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
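
# WNA16: weight-only quantization, N-bit integer weights with 16-bit
# (fp16/bf16) activations. One scheme instance describes one quantized
# linear layer.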
class CompressedTensorsWNA16(CompressedTensorsScheme):
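    # Tracks which kernel backends have been logged already, so each
    # backend is announced only once per process (the set is shared
    # across all instances of this scheme).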
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 actorder: Optional[ActivationOrdering] = None):
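        # Quantized values are packed into int32 words: 8 values per
        # word for 4-bit weights, 4 values per word for 8-bit weights.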
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError("Marlin kernels require group quantization or "
                             "channelwise quantization, but found no group "
                             "size and strategy is not channelwise.")

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_BITS}")

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere (compute capability 8.0) and up
        return 80

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
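        """Create and register the parameters holding the quantized
        checkpoint for one (possibly tensor-parallel) linear layer:
        the packed weights, the dequantization scales, the original
        weight shape, and, with activation reordering, the group index.
        """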
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=False,
            has_g_idx=self.has_g_idx,
        )
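
        # Pick a mixed-precision linear kernel (e.g. a Marlin-based one)
        # that supports this exact configuration.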
        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info(
                f"Using {kernel_type.__name__} for CompressedTensorsWNA16")
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in the channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
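        # Scales are replicated on every rank when act-order is enabled,
        # or when channelwise quantization meets a row-parallel layer;
        # otherwise each rank keeps only its own slice of the scales.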
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size
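
        # Packed quantized weights: shape
        # (output_size_per_partition, input_size_per_partition // pack_factor),
        # with pack_factor consecutive N-bit values per int32 word along
        # the input dimension.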
        weight = PackedAphroditeParameter(input_dim=1,
                                          output_dim=0,
                                          weight_loader=weight_loader,
                                          packed_factor=self.pack_factor,
                                          packed_dim=1,
                                          data=torch.empty(
                                              output_size_per_partition,
                                              input_size_per_partition //
                                              self.pack_factor,
                                              dtype=torch.int32,
                                          ))

        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)

        # The original (unpacked) 2D shape of the weight, stored as two
        # int64 entries.
        weight_shape = BaseAphroditeParameter(
            data=torch.empty(2, dtype=torch.int64),
            weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        # Group index (for activation reordering): maps each input
        # channel to its quantization group.
        if self.has_g_idx:
            weight_g_idx = RowAphroditeParameter(
                data=torch.empty(input_size_per_partition,
                                 dtype=torch.int32),
                input_dim=0,
                weight_loader=weight_loader)
            layer.register_parameter("weight_g_idx", weight_g_idx)
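
        # Tell the chosen kernel which registered parameters hold the
        # quantized weights, scales, and (optional) group index. Zero
        # points are unused: the biased storage types are symmetric.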
        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="weight_packed",
                                  w_s_param_name="weight_scale",
                                  w_zp_param_name=None,
                                  w_gidx_param_name="weight_g_idx")

    # Checkpoints are serialized in the compressed-tensors format, which
    # may differ from the layout the chosen kernel wants. Repacking is
    # handled here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)
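
# Minimal sketch of standalone use, assuming the aphrodite package is
# importable. Constructing the scheme does no GPU work; kernels are only
# touched in create_weights/apply_weights.
if __name__ == "__main__":
    scheme = CompressedTensorsWNA16(strategy="group", num_bits=4,
                                    group_size=128)
    print(f"quant_type={scheme.quant_type}, "
          f"pack_factor={scheme.pack_factor}, "
          f"has_g_idx={scheme.has_g_idx}")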