# compressed_tensors_wNa16.py

from typing import Callable, List, Optional

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedAphroditeParameter)
from aphrodite.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from aphrodite.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, replace_tensor, verify_marlin_supported,
    verify_marlin_supports_shape)
from aphrodite.scalar_type import scalar_types

__all__ = ["CompressedTensorsWNA16"]

WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
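
# With num_bits = 4 the pack factor is 32 // 4 = 8, so each int32 word in the
# packed weight tensor holds eight quantized values; with num_bits = 8 it
# holds four. The mapped scalar types are biased unsigned formats (uint4 with
# a bias of 8, uint8 with a bias of 128), matching what the GPTQ-Marlin
# kernels expect.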


class CompressedTensorsWNA16(CompressedTensorsScheme):
    """Weight-only quantization: N-bit (4- or 8-bit) weights with 16-bit
    activations, executed through the GPTQ-Marlin kernels."""

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.group_size = -1 if group_size is None else group_size

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError("Marlin kernels require group quantization or "
                             "channelwise quantization, but found no group "
                             "size and strategy is not channelwise.")

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_BITS}")

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

        # Verify that the quantization type and group size are supported on
        # this platform.
        verify_marlin_supported(quant_type=self.quant_type,
                                group_size=self.group_size)

    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere and up.
        return 80

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)

        # If group_size is -1, we are in the channelwise case.
        channelwise = (self.group_size == -1)
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        # In the case of channelwise quantization, we need to replicate the
        # scales across all GPUs.
        partition_scales = (row_parallel and not channelwise)

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        scales_and_zp_size = input_size // group_size
        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights, packed along the input dimension:
        # [output_size_per_partition, input_size_per_partition // pack_factor],
        # stored as int32.
        weight = PackedAphroditeParameter(input_dim=1,
                                          output_dim=0,
                                          weight_loader=weight_loader,
                                          packed_factor=self.pack_factor,
                                          packed_dim=1,
                                          data=torch.empty(
                                              output_size_per_partition,
                                              input_size_per_partition //
                                              self.pack_factor,
                                              dtype=torch.int32,
                                          ))

        # Quantization scales in the activation dtype:
        # [output_size_per_partition, scales_and_zp_size], i.e. one scale per
        # output channel (channelwise) or per output channel and group.
        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            )
        }
        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)

        # A 2D array defining the original shape of the weights
        # before packing.
        weight_shape = BaseAphroditeParameter(
            data=torch.empty(2, dtype=torch.int64),
            weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.group_size = group_size

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.weight_packed.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Act-order not supported in compressed-tensors yet, so set to empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # No zero-point.
        layer.weight_zp = marlin_make_empty_g_idx(device)

        # Update for kernel.
        layer.weight_packed = torch.nn.Parameter(
            layer.weight_packed.t().contiguous(), requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed,
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_type.size_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return apply_gptq_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            weight_zp=layer.weight_zp,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            wtype=self.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
            bias=bias)
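

# Typical lifecycle of this scheme (illustrative sketch only; the `layer`,
# `weight_loader`, and `x` wiring shown here is an assumption about the
# caller, not part of this module):
#
#   scheme = CompressedTensorsWNA16(strategy="group", num_bits=4,
#                                   group_size=128)
#   scheme.create_weights(layer, input_size=k, output_partition_sizes=[n],
#                         input_size_per_partition=k,
#                         params_dtype=torch.float16,
#                         weight_loader=weight_loader)
#   # ... checkpoint tensors are loaded into layer.weight_packed,
#   #     layer.weight_scale, and layer.weight_shape ...
#   scheme.process_weights_after_loading(layer)  # repack for the Marlin kernel
#   out = scheme.apply_weights(layer, x, bias=None)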