# marlin_utils.py
  1. from typing import List, Optional, Tuple
  2. import numpy
  3. import torch
  4. from aphrodite import _custom_ops as ops
  5. from aphrodite.platforms import current_platform
  6. from aphrodite.scalar_type import ScalarType, scalar_types
  7. from .quant_utils import pack_cols, unpack_cols
  8. GPTQ_MARLIN_TILE = 16
  9. GPTQ_MARLIN_MIN_THREAD_N = 64
  10. GPTQ_MARLIN_MIN_THREAD_K = 128
  11. GPTQ_MARLIN_MAX_PARALLEL = 16
  12. MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
  13. # In case there is a performance issue with Marlin, the variable below can be
  14. # changed to False, which allows Marlin to perform global reductions in fp16
  15. # precision (instead of fp32), and therefore, save on some memory movements.
  16. USE_FP32_REDUCE_DEFAULT = True
  17. # For binary size and compile time, we don't support the same types for with and
  18. # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
  19. # TODO: we may want to move this into the C++ so its closer to the actual impl
  20. def query_marlin_supported_quant_types(has_zp: bool,
  21. device_capability: Optional[int] = None
  22. ):
  23. if device_capability is None:
  24. major, minor = current_platform.get_device_capability()
  25. device_capability = major * 10 + minor
  26. if device_capability < 80:
  27. return []
  28. if has_zp:
  29. # AWQ style, unsigned + runtime zero-point
  30. return [scalar_types.uint4, scalar_types.uint8]
  31. else:
  32. # GPTQ style, unsigned + symmetric bias
  33. # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
  34. # to add `scalar_types.float8_e4m3fn` here
  35. return [scalar_types.uint4b8, scalar_types.uint8b128]
  36. def _check_marlin_supported(
  37. quant_type: ScalarType,
  38. group_size: Optional[int],
  39. has_zp: bool,
  40. device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
  41. if device_capability is None:
  42. major, minor = current_platform.get_device_capability()
  43. device_capability = major * 10 + minor
  44. supported_types = query_marlin_supported_quant_types(
  45. has_zp, device_capability)
  46. if quant_type not in supported_types:
  47. return (False, f"Marlin does not support weight_bits = {quant_type}. "
  48. f"Only types = {supported_types} "
  49. f"are supported (for group_size = {group_size}, "
  50. f"device_capability = {device_capability}, zp = {has_zp}).")
  51. if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
  52. return (False, f"Marlin does not support group_size = {group_size}. "
  53. f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
  54. "are supported.")
  55. return True, None
  56. def check_marlin_supported(quant_type: ScalarType,
  57. group_size: int,
  58. has_zp: bool = False,
  59. device_capability: Optional[int] = None) -> bool:
  60. cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
  61. device_capability)
  62. return cond
  63. def verify_marlin_supported(quant_type: ScalarType,
  64. group_size: int,
  65. has_zp: bool = False) -> None:
  66. cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
  67. if not cond:
  68. assert err_msg is not None
  69. raise ValueError(err_msg)
  70. def verify_marlin_supports_shape(output_size_per_partition: int,
  71. input_size_per_partition: int,
  72. input_size: int, group_size: int) -> None:
  73. # Validate output_size_per_partition
  74. if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
  75. raise ValueError(f"Weight output_size_per_partition = "
  76. f"{output_size_per_partition} is not divisible by "
  77. f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
  78. "Consider reducing tensor_parallel_size or running "
  79. "with --quantization gptq.")
  80. # Validate input_size_per_partition
  81. if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
  82. raise ValueError(f"Weight input_size_per_partition = "
  83. f"{input_size_per_partition} is not divisible "
  84. f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
  85. "Consider reducing tensor_parallel_size or running "
  86. "with --quantization gptq.")
  87. if (group_size < input_size
  88. and input_size_per_partition % group_size != 0):
  89. raise ValueError(
  90. f"Weight input_size_per_partition = {input_size_per_partition}"
  91. f" is not divisible by group_size = {group_size}."
  92. "Consider reducing tensor_parallel_size or running "
  93. "with --quantization gptq.")
  94. def marlin_make_workspace(output_size_per_partition: int,
  95. device: torch.device) -> torch.Tensor:
  96. max_workspace_size = (output_size_per_partition //
  97. GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
  98. return torch.zeros(max_workspace_size,
  99. dtype=torch.int,
  100. device=device,
  101. requires_grad=False)
  102. def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
  103. return (not act_order) or (act_order and not is_row_parallel)
  104. def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
  105. is_row_parallel: bool) -> bool:
  106. # Need to repeat scales on every rank if act_ordering or
  107. # channelwise and RowParallelLinear
  108. is_channelwise = group_size == -1
  109. return act_order or (is_channelwise and is_row_parallel)
  110. def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
  111. return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
  112. requires_grad=False)
  113. def marlin_sort_g_idx(
  114. g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  115. g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
  116. return g_idx[g_idx_sort_indices], g_idx_sort_indices
  117. def get_scale_perms():
  118. scale_perm: List[int] = []
  119. for i in range(8):
  120. scale_perm.extend([i + 8 * j for j in range(8)])
  121. scale_perm_single: List[int] = []
  122. for i in range(4):
  123. scale_perm_single.extend(
  124. [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
  125. return scale_perm, scale_perm_single
  126. def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
  127. group_size: int) -> torch.Tensor:
  128. scale_perm, scale_perm_single = get_scale_perms()
  129. if group_size < size_k and group_size != -1:
  130. s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
  131. else:
  132. s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
  133. s = s.reshape((-1, size_n)).contiguous()
  134. return s
  135. def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
  136. num_bits: int) -> torch.Tensor:
  137. # Permute zero-points in a similar way to scales, but do not use the
  138. # "single" permutation, since zero-points are applied on every MMA
  139. scale_perm, _ = get_scale_perms()
  140. zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
  141. # Interleave column dim (for the dequantize code) and pack it to int32
  142. if num_bits == 4:
  143. interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
  144. elif num_bits == 8:
  145. interleave = numpy.array([0, 2, 1, 3])
  146. else:
  147. raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
  148. zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
  149. zp = zp.reshape((-1, size_n)).contiguous()
  150. zp = pack_cols(zp, num_bits, size_k, size_n)
  151. return zp
  152. def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
  153. size_n: int, num_bits: int) -> torch.Tensor:
  154. # AWQ zero-points are quantized and packed on the column dim.
  155. # In addition, the values are permuted based on dequantizer.
  156. # Here we undo both of these, and then apply marlin permutation
  157. # and pack it back.
  158. q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
  159. # Undo interleaving (use argsort(..) to get inverse perm)
  160. if num_bits == 4:
  161. undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
  162. elif num_bits == 8:
  163. undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
  164. else:
  165. raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
  166. q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
  167. q_zp = q_zp.reshape((-1, size_n)).contiguous()
  168. marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
  169. return marlin_zp
  170. # Newly generated tensors need to replace existing tensors that are
  171. # already registered as parameters by Aphrodite (and won't be freed)
  172. def replace_tensor(layer: torch.nn.Module, name: str,
  173. new_t: torch.Tensor) -> None:
  174. # It is important to use resize_() here since it ensures
  175. # the same buffer is reused
  176. getattr(layer, name).resize_(new_t.shape)
  177. getattr(layer, name).copy_(new_t)
  178. del new_t
  179. def apply_gptq_marlin_linear(
  180. input: torch.Tensor,
  181. weight: torch.Tensor,
  182. weight_scale: torch.Tensor,
  183. weight_zp: torch.Tensor,
  184. g_idx: torch.Tensor,
  185. g_idx_sort_indices: torch.Tensor,
  186. workspace: torch.Tensor,
  187. wtype: ScalarType,
  188. output_size_per_partition: int,
  189. input_size_per_partition: int,
  190. is_k_full: bool,
  191. bias: Optional[torch.Tensor] = None,
  192. use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
  193. reshaped_x = input.reshape(-1, input.shape[-1])
  194. out_shape = input.shape[:-1] + (output_size_per_partition, )
  195. output = ops.gptq_marlin_gemm(reshaped_x,
  196. weight,
  197. weight_scale,
  198. weight_zp,
  199. g_idx,
  200. g_idx_sort_indices,
  201. workspace,
  202. wtype,
  203. size_m=reshaped_x.shape[0],
  204. size_n=output_size_per_partition,
  205. size_k=input_size_per_partition,
  206. is_k_full=is_k_full,
  207. has_zp=False,
  208. use_fp32_reduce=use_fp32_reduce,
  209. is_zp_float=False)
  210. if bias is not None:
  211. output.add_(bias) # In-place add
  212. return output.reshape(out_shape)
  213. def apply_awq_marlin_linear(
  214. input: torch.Tensor,
  215. weight: torch.Tensor,
  216. weight_scale: torch.Tensor,
  217. weight_zp: torch.Tensor,
  218. g_idx: torch.Tensor,
  219. g_idx_sort_indices: torch.Tensor,
  220. workspace: torch.Tensor,
  221. quant_type: ScalarType,
  222. output_size_per_partition: int,
  223. input_size_per_partition: int,
  224. bias: Optional[torch.Tensor] = None,
  225. use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
  226. reshaped_x = input.reshape(-1, input.shape[-1])
  227. out_shape = input.shape[:-1] + (output_size_per_partition, )
  228. output = ops.gptq_marlin_gemm(reshaped_x,
  229. weight,
  230. weight_scale,
  231. weight_zp,
  232. g_idx,
  233. g_idx_sort_indices,
  234. workspace,
  235. quant_type,
  236. size_m=reshaped_x.shape[0],
  237. size_n=output_size_per_partition,
  238. size_k=input_size_per_partition,
  239. is_k_full=True,
  240. has_zp=True,
  241. use_fp32_reduce=use_fp32_reduce,
  242. is_zp_float=True)
  243. if bias is not None:
  244. output.add_(bias) # In-place add
  245. return output.reshape(out_shape)