1
0

marlin_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. from typing import List, Optional, Tuple
  2. import numpy
  3. import torch
  4. from aphrodite import _custom_ops as ops
  5. from aphrodite.platforms import current_platform
  6. from aphrodite.scalar_type import ScalarType, scalar_types
  7. from .quant_utils import pack_cols, unpack_cols
  8. GPTQ_MARLIN_TILE = 16
  9. GPTQ_MARLIN_MIN_THREAD_N = 64
  10. GPTQ_MARLIN_MIN_THREAD_K = 128
  11. GPTQ_MARLIN_MAX_PARALLEL = 16
  12. MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
  13. # In case there is a performance issue with Marlin, the variable below can be
  14. # changed to False, which allows Marlin to perform global reductions in fp16
  15. # precision (instead of fp32), and therefore, save on some memory movements.
  16. USE_FP32_REDUCE_DEFAULT = True
  17. # For binary size and compile time, we don't support the same types for with and
  18. # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
  19. # TODO: we may want to move this into the C++ so its closer to the actual impl
  20. def query_marlin_supported_quant_types(has_zp: bool,
  21. device_capability: Optional[int] = None
  22. ):
  23. if device_capability is None:
  24. major, minor = current_platform.get_device_capability()
  25. device_capability = major * 10 + minor
  26. if device_capability < 80:
  27. return []
  28. if has_zp:
  29. # AWQ style, unsigned + runtime zero-point
  30. return [scalar_types.uint4, scalar_types.uint8]
  31. else:
  32. # GPTQ style, unsigned + symmetric bias
  33. # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
  34. # to add `scalar_types.float8_e4m3fn` here
  35. return [scalar_types.uint4b8, scalar_types.uint8b128]
  36. def _check_marlin_supported(
  37. quant_type: ScalarType,
  38. group_size: Optional[int],
  39. has_zp: bool,
  40. device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
  41. if device_capability is None:
  42. major, minor = current_platform.get_device_capability()
  43. device_capability = major * 10 + minor
  44. supported_types = query_marlin_supported_quant_types(
  45. has_zp, device_capability)
  46. if quant_type not in supported_types:
  47. return (False, f"Marlin does not support weight_bits = {quant_type}. "
  48. f"Only types = {supported_types} "
  49. f"are supported (for group_size = {group_size}, "
  50. f"device_capability = {device_capability}, zp = {has_zp}).")
  51. if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
  52. return (False, f"Marlin does not support group_size = {group_size}. "
  53. f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
  54. "are supported.")
  55. return True, None
  56. def check_marlin_supported(quant_type: ScalarType,
  57. group_size: int,
  58. has_zp: bool = False,
  59. device_capability: Optional[int] = None) -> bool:
  60. cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
  61. device_capability)
  62. return cond
  63. def verify_marlin_supported(quant_type: ScalarType,
  64. group_size: int,
  65. has_zp: bool = False) -> None:
  66. cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
  67. if not cond:
  68. assert err_msg is not None
  69. raise ValueError(err_msg)
  70. def verify_marlin_supports_shape(output_size_per_partition: int,
  71. input_size_per_partition: int,
  72. input_size: int, group_size: int) -> None:
  73. # Validate output_size_per_partition
  74. if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
  75. raise ValueError(f"Weight output_size_per_partition = "
  76. f"{output_size_per_partition} is not divisible by "
  77. f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
  78. "Consider reducing tensor_parallel_size or running "
  79. "with --quantization gptq.")
  80. # Validate input_size_per_partition
  81. if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
  82. raise ValueError(f"Weight input_size_per_partition = "
  83. f"{input_size_per_partition} is not divisible "
  84. f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
  85. "Consider reducing tensor_parallel_size or running "
  86. "with --quantization gptq.")
  87. if (group_size < input_size
  88. and input_size_per_partition % group_size != 0):
  89. raise ValueError(
  90. f"Weight input_size_per_partition = {input_size_per_partition}"
  91. f" is not divisible by group_size = {group_size}."
  92. "Consider reducing tensor_parallel_size or running "
  93. "with --quantization gptq.")
  94. def check_marlin_supports_shape(output_size_per_partition: int,
  95. input_size_per_partition: int,
  96. input_size: int, group_size: int) \
  97. -> Tuple[bool, Optional[str]]:
  98. try:
  99. verify_marlin_supports_shape(output_size_per_partition,
  100. input_size_per_partition, input_size,
  101. group_size)
  102. except ValueError as e:
  103. return False, e.__str__()
  104. return True, None
  105. def marlin_make_workspace(output_size_per_partition: int,
  106. device: torch.device) -> torch.Tensor:
  107. max_workspace_size = (output_size_per_partition //
  108. GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
  109. return torch.zeros(max_workspace_size,
  110. dtype=torch.int,
  111. device=device,
  112. requires_grad=False)
  113. def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
  114. return (not act_order) or (act_order and not is_row_parallel)
  115. def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
  116. is_row_parallel: bool) -> bool:
  117. # Need to repeat scales on every rank if act_ordering or
  118. # channelwise and RowParallelLinear
  119. is_channelwise = group_size == -1
  120. return act_order or (is_channelwise and is_row_parallel)
  121. def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
  122. return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
  123. requires_grad=False)
  124. def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
  125. return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
  126. requires_grad=False)
  127. def marlin_sort_g_idx(
  128. g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  129. g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
  130. return g_idx[g_idx_sort_indices], g_idx_sort_indices
  131. def get_scale_perms():
  132. scale_perm: List[int] = []
  133. for i in range(8):
  134. scale_perm.extend([i + 8 * j for j in range(8)])
  135. scale_perm_single: List[int] = []
  136. for i in range(4):
  137. scale_perm_single.extend(
  138. [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
  139. return scale_perm, scale_perm_single
  140. def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
  141. group_size: int) -> torch.Tensor:
  142. scale_perm, scale_perm_single = get_scale_perms()
  143. if group_size < size_k and group_size != -1:
  144. s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
  145. else:
  146. s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
  147. s = s.reshape((-1, size_n)).contiguous()
  148. return s
  149. def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
  150. num_bits: int) -> torch.Tensor:
  151. # Permute zero-points in a similar way to scales, but do not use the
  152. # "single" permutation, since zero-points are applied on every MMA
  153. scale_perm, _ = get_scale_perms()
  154. zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
  155. # Interleave column dim (for the dequantize code) and pack it to int32
  156. if num_bits == 4:
  157. interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
  158. elif num_bits == 8:
  159. interleave = numpy.array([0, 2, 1, 3])
  160. else:
  161. raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
  162. zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
  163. zp = zp.reshape((-1, size_n)).contiguous()
  164. zp = pack_cols(zp, num_bits, size_k, size_n)
  165. return zp
  166. def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
  167. size_n: int, num_bits: int) -> torch.Tensor:
  168. # AWQ zero-points are quantized and packed on the column dim.
  169. # In addition, the values are permuted based on dequantizer.
  170. # Here we undo both of these, and then apply marlin permutation
  171. # and pack it back.
  172. q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
  173. # Undo interleaving (use argsort(..) to get inverse perm)
  174. if num_bits == 4:
  175. undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
  176. elif num_bits == 8:
  177. undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
  178. else:
  179. raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
  180. q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
  181. q_zp = q_zp.reshape((-1, size_n)).contiguous()
  182. marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
  183. return marlin_zp
  184. def apply_gptq_marlin_linear(
  185. input: torch.Tensor,
  186. weight: torch.Tensor,
  187. weight_scale: torch.Tensor,
  188. weight_zp: torch.Tensor,
  189. g_idx: torch.Tensor,
  190. g_idx_sort_indices: torch.Tensor,
  191. workspace: torch.Tensor,
  192. wtype: ScalarType,
  193. output_size_per_partition: int,
  194. input_size_per_partition: int,
  195. is_k_full: bool,
  196. bias: Optional[torch.Tensor] = None,
  197. use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
  198. reshaped_x = input.reshape(-1, input.shape[-1])
  199. out_shape = input.shape[:-1] + (output_size_per_partition, )
  200. output = ops.gptq_marlin_gemm(reshaped_x,
  201. weight,
  202. weight_scale,
  203. weight_zp,
  204. g_idx,
  205. g_idx_sort_indices,
  206. workspace,
  207. wtype,
  208. size_m=reshaped_x.shape[0],
  209. size_n=output_size_per_partition,
  210. size_k=input_size_per_partition,
  211. is_k_full=is_k_full,
  212. has_zp=False,
  213. use_fp32_reduce=use_fp32_reduce,
  214. is_zp_float=False)
  215. if bias is not None:
  216. output.add_(bias) # In-place add
  217. return output.reshape(out_shape)
  218. def apply_awq_marlin_linear(
  219. input: torch.Tensor,
  220. weight: torch.Tensor,
  221. weight_scale: torch.Tensor,
  222. weight_zp: torch.Tensor,
  223. g_idx: torch.Tensor,
  224. g_idx_sort_indices: torch.Tensor,
  225. workspace: torch.Tensor,
  226. quant_type: ScalarType,
  227. output_size_per_partition: int,
  228. input_size_per_partition: int,
  229. bias: Optional[torch.Tensor] = None,
  230. use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
  231. reshaped_x = input.reshape(-1, input.shape[-1])
  232. out_shape = input.shape[:-1] + (output_size_per_partition, )
  233. output = ops.gptq_marlin_gemm(reshaped_x,
  234. weight,
  235. weight_scale,
  236. weight_zp,
  237. g_idx,
  238. g_idx_sort_indices,
  239. workspace,
  240. quant_type,
  241. size_m=reshaped_x.shape[0],
  242. size_n=output_size_per_partition,
  243. size_k=input_size_per_partition,
  244. is_k_full=True,
  245. has_zp=True,
  246. use_fp32_reduce=use_fp32_reduce,
  247. is_zp_float=True)
  248. if bias is not None:
  249. output.add_(bias) # In-place add
  250. return output.reshape(out_shape)