marlin_utils.py

from typing import List, Optional, Tuple

import numpy
import torch

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform
from aphrodite.scalar_type import ScalarType, scalar_types

from .quant_utils import pack_cols, unpack_cols

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32) and therefore save on some memory movement.
USE_FP32_REDUCE_DEFAULT = True


# For binary size and compile time, we don't support the same set of types
# with and without runtime zero-point. We support the common cases, i.e.
# AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual
# implementation.
def query_marlin_supported_quant_types(has_zp: bool,
                                       min_capability: Optional[int] = None):
    if min_capability is None:
        major, minor = current_platform.get_device_capability()
        min_capability = major * 10 + minor

    if min_capability < 80:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be
        # able to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]
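
# Illustrative usage (a sketch, not part of the original module): gating
# Marlin support at config-resolution time. The results noted below follow
# directly from the branches above.
#
#     # AWQ-style checkpoints carry a runtime zero-point:
#     awq_types = query_marlin_supported_quant_types(has_zp=True,
#                                                    min_capability=80)
#     # -> [scalar_types.uint4, scalar_types.uint8]
#
#     # Pre-Ampere devices (capability < 80) get an empty list, i.e. no
#     # Marlin support at all:
#     assert query_marlin_supported_quant_types(False, min_capability=75) == []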


def _check_marlin_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

    if min_capability is None:
        major, minor = current_platform.get_device_capability()
        min_capability = major * 10 + minor

    supported_types = query_marlin_supported_quant_types(
        has_zp, min_capability)

    if quant_type not in supported_types:
        return (False, f"Marlin does not support weight_bits = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"min_capability = {min_capability}, zp = {has_zp}).")
    if (group_size is None
            or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
        return (False, f"Marlin does not support group_size = {group_size}. "
                f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    return True, None


def check_marlin_supported(quant_type: ScalarType,
                           group_size: int,
                           has_zp: bool = False,
                           min_capability: Optional[int] = None) -> bool:
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
                                      min_capability)
    return cond


def verify_marlin_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False) -> None:
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_marlin_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f"min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    if (group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")
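
# Worked example (illustrative, numbers chosen for this sketch): with
# tensor_parallel_size = 4 and a column-parallel weight of shape
# [input_size = 4096, output_size = 11008], each rank gets
# output_size_per_partition = 11008 / 4 = 2752 = 43 * 64, which passes the
# min_thread_n = 64 check; input_size_per_partition = 4096 = 32 * 128 passes
# the min_thread_k = 128 check and is divisible by every supported
# group_size (32, 64, 128), so verify_marlin_supports_shape() returns
# without raising.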


def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    max_workspace_size = (output_size_per_partition //
                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(max_workspace_size,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)
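
# Sizing sketch (illustrative): for output_size_per_partition = 2752 the
# workspace holds (2752 // 64) * 16 = 688 int32 entries, i.e. one slot per
# n-tile per parallel split, which the Marlin kernel uses to synchronize its
# global reduction. The exact meaning of the slots lives in the C++ kernel,
# so treat this as an assumption about intent rather than a spec.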


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    # Need to repeat scales on every rank if act_order, or if
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices
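
# Illustrative example for marlin_sort_g_idx (not part of the original file):
#
#     g_idx = torch.tensor([1, 0, 1, 0], dtype=torch.int)
#     sorted_g_idx, perm = marlin_sort_g_idx(g_idx)
#     # sorted_g_idx == tensor([0, 0, 1, 1], dtype=torch.int32); perm holds
#     # the indices that reorder the activation-ordered columns into
#     # contiguous group order for the kernel.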


def get_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s
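
# Shape-level example (illustrative): for a layer with size_k = 4096,
# size_n = 4096 and group_size = 128, the scales arrive as a
# [4096 / 128 = 32, 4096] tensor. The first branch reshapes to
# (-1, len(scale_perm) = 64), permutes each row by scale_perm, and returns a
# contiguous [32, 4096] tensor in the column order the Marlin kernel expects.
# With group_size == -1 (channelwise) the single row of scales goes through
# scale_perm_single (32 entries) instead.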


def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave the column dim (for the dequantize code) and pack to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on the dequantizer.
    # Here we undo both of these, then apply the Marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get the inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp
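
# Note (illustrative): numpy.argsort of the 4-bit interleave
# [0, 2, 4, 6, 1, 3, 5, 7] is [0, 4, 1, 5, 2, 6, 3, 7], which maps each
# AWQ-interleaved column back to its pre-interleave position before
# marlin_zero_points() re-permutes and re-packs it.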


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by Aphrodite (and won't be freed).
def replace_tensor(layer: torch.nn.Module, name: str,
                   new_t: torch.Tensor) -> None:
    # It is important to use resize_() here since it ensures
    # the same buffer is reused
    getattr(layer, name).resize_(new_t.shape)
    getattr(layer, name).copy_(new_t)
    del new_t
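
# Typical call pattern (illustrative; `layer.scales`, `size_k`, `size_n` and
# `group_size` are placeholders for whatever the calling quantization method
# defines, not names guaranteed by this module):
#
#     marlin_scales = marlin_permute_scales(layer.scales, size_k, size_n,
#                                           group_size)
#     replace_tensor(layer, "scales", marlin_scales)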


def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        wtype: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        is_k_full: bool,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  weight_zp,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  wtype,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full,
                                  has_zp=False,
                                  use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
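
# Illustrative call from a linear method's apply() (the `layer.*` attribute
# names below are placeholders, not guaranteed by this module):
#
#     out = apply_gptq_marlin_linear(
#         input=x,
#         weight=layer.qweight,
#         weight_scale=layer.scales,
#         weight_zp=layer.zp,
#         g_idx=layer.g_idx,
#         g_idx_sort_indices=layer.g_idx_sort_indices,
#         workspace=layer.workspace,
#         wtype=scalar_types.uint4b8,
#         output_size_per_partition=layer.output_size_per_partition,
#         input_size_per_partition=layer.input_size_per_partition,
#         is_k_full=True,
#         bias=layer.bias)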


def apply_awq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        quant_type: ScalarType,
        output_size_per_partition: int,
        input_size_per_partition: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  weight_zp,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  quant_type,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=True,
                                  has_zp=True,
                                  use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)