compressed_tensors_wNa16.py

from typing import Callable, List, Optional

import torch
from torch.nn import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.compressed_tensors.schemes import \
    CompressedTensorsScheme
from aphrodite.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
    verify_marlin_supports_shape)

__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_BITS = [4, 8]


class CompressedTensorsWNA16(CompressedTensorsScheme):

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        self.num_bits = num_bits
        self.pack_factor = 32 // self.num_bits
        self.strategy = strategy

        self.group_size: int
        if group_size is None:
            if self.strategy != "channel":
                raise ValueError(
                    "Marlin kernels require group quantization or "
                    "channelwise quantization, but found no group "
                    "size and strategy is not channelwise.")
            self.group_size = -1
        else:
            self.group_size = group_size

        # Verify supported on platform.
        verify_gptq_marlin_supported(num_bits=self.num_bits,
                                     group_size=self.group_size,
                                     is_sym=True)
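
    # Pack-factor arithmetic, for illustration: with num_bits=4 the pack
    # factor is 32 // 4 == 8, i.e. eight 4-bit weights share one int32 word;
    # with num_bits=8 it is 4. create_weights() below sizes the packed
    # weight's input dimension as input_size_per_partition // pack_factor.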

    def get_min_capability(self) -> int:
        # ampere and up
        return 80

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        output_size_per_partition = sum(output_partition_sizes)

        # If group_size is -1, we are in channelwise case.
        group_size = input_size if self.group_size == -1 else self.group_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        weight_scale_dim = None
        scales_and_zp_size = input_size // group_size

        if (input_size != input_size_per_partition
                and self.group_size is not None):
            weight_scale_dim = 1
            scales_and_zp_size = input_size_per_partition // group_size

        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.pack_factor,
                "weight_loader": weight_loader
            })
        layer.register_parameter("weight_packed", weight)

        weight_scale = Parameter(
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            weight_scale, {
                "weight_loader": weight_loader,
                "input_dim": weight_scale_dim,
                "output_dim": 0
            })
        layer.register_parameter("weight_scale", weight_scale)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
                                 requires_grad=False)
        layer.register_parameter("weight_shape", weight_shape)
        set_weight_attrs(weight_shape, {
            "weight_loader": weight_loader,
            "ignore_warning": True,
        })

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.group_size = group_size
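
        # Shape sketch (illustrative numbers, not taken from a real config):
        # for num_bits=4 and group_size=128 on an unsharded 4096x4096 layer,
        # pack_factor is 8, so weight_packed is an int32 tensor of shape
        # [4096, 4096 // 8] == [4096, 512] and weight_scale has shape
        # [4096, 4096 // 128] == [4096, 32].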

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.weight_packed.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Act-order not supported in compressed-tensors yet, so set to empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # No zero-point
        layer.weight_zp = marlin_make_empty_g_idx(device)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed.t().contiguous(),
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.num_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale.squeeze().t().contiguous(),
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return apply_gptq_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            weight_zp=layer.weight_zp,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.num_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
            bias=bias)
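

# Minimal usage sketch, assuming an Ampere-or-newer GPU with Aphrodite's
# Marlin kernels available (the scheme verifies platform support when it is
# constructed). The shapes and the no-op weight loader below are illustrative
# placeholders, not values taken from a real model config.
if __name__ == "__main__":

    def _noop_weight_loader(param: torch.Tensor, *args, **kwargs) -> None:
        # Stand-in for the real checkpoint loader, which copies the
        # compressed-tensors weights into the registered parameters.
        pass

    scheme = CompressedTensorsWNA16(strategy="group", num_bits=4,
                                    group_size=128)
    layer = torch.nn.Module()
    scheme.create_weights(layer,
                          input_size=4096,
                          output_partition_sizes=[4096],
                          input_size_per_partition=4096,
                          params_dtype=torch.float16,
                          weight_loader=_noop_weight_loader)

    # Once the model loader has populated layer.weight_packed and
    # layer.weight_scale with real quantized data, the weights are repacked
    # into Marlin format a single time and reused for every forward pass:
    #     scheme.process_weights_after_loading(layer)
    #     out = scheme.apply_weights(layer, x, bias=None)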