  1. """Tests for the machete kernel.
  2. Run `pytest tests/kernels/test_machete_gemm.py`.
  3. """

import math
from typing import Optional, Tuple

import pytest
import torch

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform
from aphrodite.quantization.utils.quant_utils import (pack_rows,
                                                      quantize_weights)
from aphrodite.scalar_type import ScalarType, scalar_types
from tests.kernels.utils import opcheck

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (1, 4096, 4096),
    (257, 128, 4096),
    (257, 4224, 4160),
    (257, 4096, 4096),
    (64, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

ACT_TYPES = [torch.float16, torch.bfloat16]
WTYPE_ZEROPOINTS = [
    # GPTQ style
    (scalar_types.uint4b8, False),
    (scalar_types.uint8b128, False),
    # AWQ style
    (scalar_types.uint4, True),
    (scalar_types.uint8, True),
]

# TODO: in a future PR, refactor this and `is_quant_method_supported` in the
# kernel unit tests into a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods,
# an assumption which is breaking down as quantization methods can have
# multiple kernels and some kernels support multiple quantization methods.
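# Machete requires a GPU with compute capability 9.0 or newer (e.g. Hopper),
# hence the major-version check below.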
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9


def rand_data(shape, dtype=torch.float16):
    return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3)
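

# machete applies zero points after scales, so the integer zero points produced
# by `quantize_weights` are folded into scale space (-s * zp) before being
# passed to the kernel.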
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
    return zps if zps is None else -1 * s * (zps.to(s.dtype))
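

# Quantize `w` to `wtype`, pack the quantized values row-wise, convert the
# packed tensor to column-major, and reorder it into the layout expected by
# the machete kernel via `ops.machete_prepack_B`. Returns the dequantized
# reference weights alongside the packed weights, scales and zero points.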
def machete_quantize_and_pack(w: torch.Tensor,
                              wtype: ScalarType,
                              group_size: int,
                              zero_points: bool = False):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w,
        wtype,
        group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True)

    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # convert to col major
    w_q_machete = ops.machete_prepack_B(w_q, wtype)

    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))

    return w_ref, w_q_machete, w_s, w_zp
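

# Shared check used by several tests below: quantize/pack `b`, run the machete
# GEMM against `a`, and compare with a reference matmul on the dequantized
# weights.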
def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
                             wtype: ScalarType, group_size: int,
                             zero_points: bool):
    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
        b, wtype, group_size, zero_points)

    output_ref = torch.matmul(a, w_ref)

    output = ops.machete_gemm(
        a=a,
        b_q=w_q_packed,
        b_type=wtype,
        b_scales=w_s,
        b_zeros=maybe_convert_zeropoints(w_zp, w_s),
        b_group_size=group_size,
    )

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    # zeropoints (after scales) causes noise around 0
    atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_all_schedules(shape, atype: torch.dtype,
                               wtype_zeropoints: Tuple[ScalarType, bool],
                               group_size: Optional[int]):
    m, n, k = shape
    wtype, zero_points = wtype_zeropoints

    if group_size is not None and k % group_size != 0:
        return

    print(f"MNK = {m} {n} {k}")

    # Normalize group_size
    if group_size is None:
        group_size = k
    assert group_size <= k

    a = rand_data((m, k), atype)
    w = rand_data((k, n), atype)

    w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
        w, wtype, group_size, zero_points)

    output_ref = torch.matmul(a, w_ref)
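
    # Exercise every schedule machete reports for this weight type; `opcheck`
    # additionally validates the registered custom op for each call.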
    for schedule in ops.machete_supported_schedules(wtype):
        print(f"Testing schedule {schedule}")
        output = ops.machete_gemm(
            a,
            b_q=w_q_machete,
            b_type=wtype,
            b_scales=w_s,
            b_zeros=maybe_convert_zeropoints(w_zp, w_s),
            b_group_size=group_size,
            schedule=schedule,
        )

        opcheck(torch.ops._C.machete_gemm,
                (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
                    w_zp, w_s), group_size, None, None, None, schedule))

        # Relax atol as our reduction dim becomes larger (more rounding error)
        # Relax atol when we have zeropoints since the way machete applies
        # zeropoints (after scales) causes noise around 0
        atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
        torch.testing.assert_close(output,
                                   output_ref,
                                   rtol=1e-1,
                                   atol=atol,
                                   msg=f"Schedule failed {schedule}")


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_heuristic(shape, atype: torch.dtype,
                           wtype_zeropoints: Tuple[ScalarType, bool],
                           group_size: Optional[int]):
    m, n, k = shape
    wtype, zero_points = wtype_zeropoints

    if group_size is not None and k % group_size != 0:
        return

    # Normalize group_size
    if group_size is None:
        group_size = k
    assert group_size <= k

    a = rand_data((m, k), atype)
    b = rand_data((k, n), atype)

    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)


# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
    m, n, k = 512, 4096, 4096
    wtype = scalar_types.uint4b8
    group_size = 128
    zero_points = False

    print(f"MNK = {m} {n} {k}, device = {device}")

    a = rand_data((m, k), torch.float16).to(device)
    b = rand_data((k, n), torch.float16).to(device)

    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)


# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
def test_machete_subset():
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    wtype = scalar_types.uint4b8
    group_size = 128
    zero_points = False

    whole_a = rand_data((big_m, big_k), torch.float16)
    whole_b = rand_data((big_k, big_n), torch.float16)

    a = whole_a[0:m, 0:k]
    b = whole_b[0:k, 0:n]

    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)


# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, a):
        return ops.machete_gemm(**self.kwargs)
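

# Note: `MacheteLayer.forward` ignores its argument and reuses the tensors
# captured in `kwargs` (including `a`), so replaying a CUDA graph of this
# layer re-reads the same input buffers and writes the same output buffer.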


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="Machete is not supported on this GPU type.")
def test_machete_cuda_graph():
    m, n, k = 512, 4096, 4096

    a = rand_data((m, k), torch.float16)
    b = rand_data((k, n), torch.float16)
    wtype = scalar_types.uint4b8
    group_size = 128
    zero_points = False

    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
        b, wtype, group_size, zero_points)

    # Construct a trivial model with a single layer that calls a machete kernel
    model = MacheteLayer(
        a=a,
        b_q=w_q_packed,
        b_type=wtype,
        b_scales=w_s,
        b_zeros=maybe_convert_zeropoints(w_zp, w_s),
        b_group_size=group_size,
    )

    output_ref = torch.matmul(a, w_ref)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            output = model(a)
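
    # Zero the output captured during graph recording, then replay the graph;
    # the replay should recompute the result into the same buffer.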
    output.zero_()
    g.replay()

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    # zeropoints (after scales) causes noise around 0
    atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)