# compressed_tensors_wNa16.py

from typing import Callable, List, Optional

import torch
from torch.nn import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from aphrodite.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, replace_tensor, verify_marlin_supported,
    verify_marlin_supports_shape)
from aphrodite.scalar_type import scalar_types

__all__ = ["CompressedTensorsWNA16"]

WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        self.pack_factor = 32 // num_bits
        self.strategy = strategy

        self.group_size: int
        if group_size is None:
            if self.strategy != "channel":
                raise ValueError(
                    "Marlin kernels require group quantization or "
                    "channelwise quantization, but found no group "
                    "size and strategy is not channelwise.")
            self.group_size = -1
        else:
            self.group_size = group_size

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

        # Verify supported on platform.
        verify_marlin_supported(quant_type=self.quant_type,
                                group_size=self.group_size)
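
    # Illustrative note (added, not in the original file): pack_factor is the
    # number of quantized values packed into each int32, i.e. 32 // 4 = 8 for
    # uint4b8 and 32 // 8 = 4 for uint8b128. A group_size of -1 marks
    # channelwise quantization, where one scale per output channel covers the
    # whole input dimension.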

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        output_size_per_partition = sum(output_partition_sizes)

        # If group_size is -1, we are in channelwise case.
        channelwise = (self.group_size == -1)
        group_size = input_size if channelwise else self.group_size
        row_parallel = (input_size != input_size_per_partition)
        # In the case of channelwise quantization, we need to replicate the
        # scales across all gpus.
        partition_scales = (row_parallel and not channelwise)

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        weight_scale_dim = None
        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            weight_scale_dim = 1
            scales_and_zp_size = input_size_per_partition // group_size

        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.pack_factor,
                "weight_loader": weight_loader
            })
        layer.register_parameter("weight_packed", weight)

        weight_scale = Parameter(
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            weight_scale, {
                "weight_loader": weight_loader,
                "input_dim": weight_scale_dim,
                "output_dim": 0
            })
        layer.register_parameter("weight_scale", weight_scale)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
                                 requires_grad=False)

        layer.register_parameter("weight_shape", weight_shape)
        set_weight_attrs(weight_shape, {
            "weight_loader": weight_loader,
            "ignore_warning": True,
        })

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.group_size = group_size
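
    # Worked example (added; sizes are hypothetical, not from the original
    # file): for a 4-bit layer with group_size = 128 and
    # input_size_per_partition = output_size_per_partition = 4096,
    # create_weights registers
    #   weight_packed: (4096, 4096 // 8)   -> (4096, 512), dtype int32
    #   weight_scale:  (4096, 4096 // 128) -> (4096, 32),  in params_dtype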

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.weight_packed.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Act-order not supported in compressed-tensors yet, so set to empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # No zero-point
        layer.weight_zp = marlin_make_empty_g_idx(device)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed.t().contiguous(),
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_type.size_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale.squeeze().t().contiguous(),
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)
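
    # Added note: weight_packed is created above as
    # (output_size_per_partition, input_size_per_partition // pack_factor),
    # while the repack kernel presumably expects the GPTQ-style
    # (K // pack_factor, N) layout, hence the .t().contiguous() calls on both
    # the packed weights and the scales before handing them to Marlin.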

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return apply_gptq_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            weight_zp=layer.weight_zp,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            wtype=self.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
            bias=bias)
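

# The sketch below is an added, illustrative usage example, not part of the
# original file. The "group" strategy string, the 4096/128 sizes, and the
# already-populated `layer` are assumptions; the Marlin kernels additionally
# require a CUDA device with compute capability >= 8.0.
def _example_usage(layer: torch.nn.Module, x: torch.Tensor,
                   weight_loader: Callable) -> torch.Tensor:
    scheme = CompressedTensorsWNA16(strategy="group", num_bits=4,
                                    group_size=128)
    scheme.create_weights(layer=layer,
                          input_size=4096,
                          input_size_per_partition=4096,
                          output_partition_sizes=[4096],
                          params_dtype=torch.float16,
                          weight_loader=weight_loader)
    # ... the model loader fills layer.weight_packed / layer.weight_scale
    # from a compressed-tensors checkpoint at this point ...
    scheme.process_weights_after_loading(layer)
    return scheme.apply_weights(layer, x, bias=None)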