test_cutlass.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. """Tests for cutlass kernels
  2. Run `pytest tests/kernels/test_cutlass.py`.
  3. """
  4. from typing import Optional, Type
  5. import pytest
  6. import torch
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.platforms import current_platform
  # Devices for the multi-device tests: a single device when exactly one GPU
  # is visible, otherwise the first two.
  CUDA_DEVICES = [
      f"cuda:{device_index}"
      for device_index in range(2 if torch.cuda.device_count() != 1 else 1)
  ]
  # Flatten the (major, minor) compute capability into one integer, e.g.
  # (8, 9) -> 89, so the skip markers can compare against plain thresholds.
  capability = sum(
      weight * part
      for weight, part in zip((10, 1),
                              current_platform.get_device_capability()))
  14. def to_fp8(tensor: torch.Tensor):
  15. finfo = torch.finfo(torch.float8_e4m3fn)
  16. return torch.round(tensor.clamp(
  17. min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
  def to_int8(tensor: torch.Tensor):
      """Saturate ``tensor`` to [-128, 127], round to nearest, cast to int8."""
      saturated = torch.clamp(tensor, min=-128, max=127)
      return saturated.round().to(torch.int8)
  def rand_int8(shape: tuple, device: str = "cuda"):
      """Return a random int8 tensor of ``shape`` covering [-128, 127].

      Uniform samples from [0, 1) are stretched onto [-128, 127) and then
      rounded and saturated by ``to_int8``.
      """
      uniform = torch.rand(shape, device=device)
      return to_int8(uniform * 255 - 128)
  def baseline_scaled_mm(a: torch.Tensor,
                         b: torch.Tensor,
                         scale_a: torch.Tensor,
                         scale_b: torch.Tensor,
                         out_dtype: torch.dtype,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
      """Reference implementation for ``ops.cutlass_scaled_mm``.

      Computes ``scale_a * (scale_b * (a @ b))`` in float32, casts the result
      to ``out_dtype`` and then adds ``bias`` (if given) in ``out_dtype``.

      Args:
          a: Left operand, upcast to float32 before the matmul.
          b: Right operand, upcast to float32 before the matmul.
          scale_a: Scale broadcast against the product; callers in this file
              pass shape (m, 1) or (1, 1).
          scale_b: Scale broadcast against the product; callers in this file
              pass shape (1, n) or (1, 1).
          out_dtype: Output dtype *instance* (e.g. ``torch.bfloat16``), hence
              annotated ``torch.dtype`` rather than ``Type[torch.dtype]``.
          bias: Optional bias, broadcast-added after the cast.

      Returns:
          The scaled product in ``out_dtype``.
      """
      product = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
      output = (scale_a * (scale_b * product)).to(out_dtype)
      if bias is not None:
          output = output + bias
      return output
  def cutlass_fp8_gemm_helper(m: int,
                              n: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
                              use_bias: bool,
                              out_dtype: torch.dtype = torch.bfloat16,
                              device: str = "cuda"):
      """Check fp8 ``ops.cutlass_scaled_mm`` against ``baseline_scaled_mm``.

      Args:
          m, n, k: GEMM problem size; A is (m, k) and B is (k, n).
          per_token_act_quant: When True, A gets one scale per row (shape
              (m, 1)); otherwise a single per-tensor scale of shape (1, 1).
          per_out_channel_weight_quant: When True, B gets one scale per
              column (shape (1, n)); otherwise a single scale of shape (1, 1).
          use_bias: Whether to pass a random bias of shape (n,).
          out_dtype: Output dtype requested from the kernel.
          device: Device on which all tensors are allocated.
      """
      a = to_fp8(torch.randn((m, k), device=device))
      # B is generated as (n, k) and transposed into a (k, n) view, presumably
      # the column-major layout the kernel expects for B.
      b = to_fp8(torch.randn((n, k), device=device).t())

      m_a_scales = m if per_token_act_quant else 1
      n_b_scales = n if per_out_channel_weight_quant else 1
      # Scales are drawn from a standard normal, so they may be negative.
      scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
      scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)

      bias = (torch.rand((n, ), device=device, dtype=out_dtype) *
              10 if use_bias else None)

      out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
      baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
      torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
  def cutlass_int8_gemm_helper(m: int,
                               n: int,
                               k: int,
                               per_token_act_quant: bool,
                               per_out_channel_weight_quant: bool,
                               use_bias: bool,
                               out_dtype: torch.dtype = torch.bfloat16,
                               device: str = "cuda"):
      """Check int8 ``ops.cutlass_scaled_mm`` against ``baseline_scaled_mm``.

      Args:
          m, n, k: GEMM problem size; A is (m, k) and B is (k, n).
          per_token_act_quant: When True, A gets one scale per row (shape
              (m, 1)); otherwise a single per-tensor scale of shape (1, 1).
          per_out_channel_weight_quant: When True, B gets one scale per
              column (shape (1, n)); otherwise a single scale of shape (1, 1).
          use_bias: Whether to pass a random bias of shape (n,).
          out_dtype: Output dtype requested from the kernel.
          device: Device on which all tensors are allocated.
      """
      # Multiplying randn by 5 before quantizing spreads values over more
      # int8 levels than plain randn would.
      a = to_int8(torch.randn((m, k), device=device) * 5)
      # B is generated as (n, k) and transposed into a (k, n) view.
      b = to_int8(torch.randn((n, k), device=device).t() * 5)

      m_a_scales = m if per_token_act_quant else 1
      n_b_scales = n if per_out_channel_weight_quant else 1
      # Scales are drawn from a standard normal, so they may be negative.
      scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
      scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)

      bias = (torch.rand((n, ), device=device, dtype=out_dtype) *
              10 if use_bias else None)

      out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
      baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
      torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
  @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
  @pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
  @pytest.mark.parametrize("k", [128, 496, 1024])
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("use_bias", [True, False])
  @pytest.mark.skipif(capability < 89,
                      reason="FP8 is not supported on this GPU type.")
  def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                            per_out_ch: bool, use_bias: bool):
      """Sweep the fp8 kernel over problem sizes and scale granularities."""
      cutlass_fp8_gemm_helper(m=m,
                              n=n,
                              k=k,
                              per_token_act_quant=per_act_token,
                              per_out_channel_weight_quant=per_out_ch,
                              use_bias=use_bias)
  # The M values match test_cutlass_fp8_gemm, with each size listed once.
  @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
  @pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
  @pytest.mark.parametrize("k", [128, 496, 1024])
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("use_bias", [True, False])
  def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                             per_out_ch: bool, use_bias: bool):
      """Sweep the int8 kernel over problem sizes and scale granularities."""
      cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  @pytest.mark.parametrize("use_bias", [True, False])
  def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                          out_dtype: torch.dtype,
                                          use_bias: bool):
      """Check each supported 16-bit output dtype for the int8 kernel.

      ``out_dtype`` is a dtype instance, so it is annotated ``torch.dtype``.
      """
      cutlass_int8_gemm_helper(512,
                               512,
                               512,
                               per_act_token,
                               per_out_ch,
                               use_bias,
                               out_dtype=out_dtype)
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  @pytest.mark.parametrize("use_bias", [True, False])
  @pytest.mark.skipif(capability < 89,
                      reason="FP8 is not supported on this GPU type.")
  def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: torch.dtype,
                                         use_bias: bool):
      """Check each supported 16-bit output dtype for the fp8 kernel.

      ``out_dtype`` is a dtype instance, so it is annotated ``torch.dtype``.
      """
      cutlass_fp8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
                              use_bias,
                              out_dtype=out_dtype)
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("use_bias", [True, False])
  @pytest.mark.parametrize("device", CUDA_DEVICES)
  @pytest.mark.skipif(capability < 89,
                      reason="FP8 is not supported on this GPU type.")
  def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                    use_bias: bool, device: str):
      """Run a fixed 512^3 fp8 GEMM on each device in ``CUDA_DEVICES``."""
      size = 512
      cutlass_fp8_gemm_helper(size,
                              size,
                              size,
                              per_act_token,
                              per_out_ch,
                              use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("use_bias", [True, False])
  @pytest.mark.parametrize("device", CUDA_DEVICES)
  def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                     use_bias: bool, device: str):
      """Run a fixed 512^3 int8 GEMM on each device in ``CUDA_DEVICES``."""
      size = 512
      cutlass_int8_gemm_helper(size, size, size, per_act_token, per_out_ch,
                               use_bias, torch.bfloat16, device)
  # For the following two tests:
  # N and K correspond to the size of the weight matrix and are likely to be
  # multiples of a large power of two; in any case, the kernel has a naive
  # fallback when N and K are not divisible by 16. M, however, is the number
  # of tokens, so the kernel must handle any value of M.
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("use_bias", [True, False])
  @pytest.mark.skipif(capability < 89,
                      reason="FP8 is not supported on this GPU type.")
  def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                    use_bias: bool):
      """Run every M in [1, 127] for each N = K in {32, 64, 96} (fp8)."""
      problem_sizes = ((nk, m) for nk in range(32, 128, 32)
                       for m in range(1, 128))
      for nk, m in problem_sizes:
          cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                  use_bias)
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  @pytest.mark.parametrize("use_bias", [True, False])
  def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                     use_bias: bool):
      """Run every M in [1, 127] for each N = K in {32, 64, 96} (int8)."""
      problem_sizes = ((nk, m) for nk in range(32, 128, 32)
                       for m in range(1, 128))
      for nk, m in problem_sizes:
          cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                   use_bias)
  @pytest.mark.parametrize("m", [32, 64, 128])
  @pytest.mark.parametrize("n", [16, 32, 64])
  @pytest.mark.parametrize("k", [64, 128, 256])
  @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  @pytest.mark.skip
  def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                      out_dtype: torch.dtype):
      """Fold a per-tensor activation zero point into the scaled_mm bias.

      The reference dequantizes A as ``scale_a * (aq + azp)``. The kernel is
      given only ``aq`` plus a bias ``azp_a * scale_b * colsum(bq)``, which is
      algebraically the same product.
      """
      # Currently, the test is failing because folding azp into
      # 16-bit bias loses too much precision
      scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
      scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

      aq_i8 = rand_int8((m, k))
      bq_i8 = rand_int8((n, k)).t()

      aq_i32 = aq_i8.to(dtype=torch.int32)
      bq_i32 = bq_i8.to(dtype=torch.int32)

      aq_f32 = aq_i8.to(dtype=torch.float32)
      bq_f32 = bq_i8.to(dtype=torch.float32)

      b_dq = scale_b * bq_f32

      azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
      # azp_a / scale_a is frequently outside the int8 range (scale_a can be
      # tiny). A bare .to(torch.int8) on such values is an undefined
      # float->int conversion, so round and saturate with to_int8 instead.
      azp_aq_i8 = to_int8(azp_a / scale_a)
      azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

      a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
      torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

      baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

      # ones(1, k) @ B is the column sum of B; the entries are small integers,
      # so the float32 sum is exact.
      b_col_sums = bq_f32.sum(dim=0, keepdim=True)
      azp_bias = (azp_a * scale_b * b_col_sums).to(out_dtype)
      assert azp_bias.shape == (1, n)
      assert azp_bias[0, :].shape == (n, )

      # Integer reference on CPU.
      baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
          (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
              dtype=out_dtype, device='cuda')

      out = ops.cutlass_scaled_mm(aq_i8,
                                  bq_i8,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=azp_bias[0, :])
      torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
      torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
  @pytest.mark.parametrize("m", [32, 64, 128])
  @pytest.mark.parametrize("n", [16, 32, 64])
  @pytest.mark.parametrize("k", [64, 128, 256])
  @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  @pytest.mark.parametrize("use_bias", [True, False])
  @pytest.mark.parametrize("azp_per_token", [True, False])
  def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                            use_bias: bool, azp_per_token: bool):
      """Check ``ops.cutlass_scaled_mm_azp`` against float and int references.

      A is dequantized as ``scale_a * (aq - azp)``. The activation zero point
      is per row when ``azp_per_token`` is True, otherwise per tensor.
      """
      m_azp = m if azp_per_token else 1
      scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
      scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

      aq_i8 = rand_int8((m, k))
      aq_i32 = aq_i8.to(dtype=torch.int32)
      aq_f32 = aq_i8.to(dtype=torch.float32)

      bq_i8 = rand_int8((n, k)).t()
      bq_i32 = bq_i8.to(dtype=torch.int32)
      bq_f32 = bq_i8.to(dtype=torch.float32)
      b_dq = scale_b * bq_f32

      azp_a = torch.rand(
          (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
      # azp_a / scale_a is frequently outside the int8 range (scale_a can be
      # tiny). A bare .to(torch.int8) on such values is an undefined
      # float->int conversion, so round and saturate with to_int8 instead.
      azp_aq_i8 = to_int8(azp_a / scale_a)
      azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

      a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
      torch.testing.assert_close(a_dq,
                                 scale_a * aq_f32 - azp_a,
                                 rtol=1e-4,
                                 atol=1e-3)

      if use_bias:
          bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
      else:
          bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)

      baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

      # int32 mm not supported on CUDA
      a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
      cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
      baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

      # The zero-point adjustment term ones(1, k) @ B is just the column sums
      # of B.
      azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
      azp_i32 = azp_aq_i8.to(dtype=torch.int32)
      func_bias = bias if use_bias else None

      if azp_per_token:
          # Per-row zero points: pass the column sums and the zero points
          # separately.
          out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                          out_dtype, azp_adj_i32, azp_i32,
                                          func_bias)
      else:
          # Per-tensor zero point: pre-multiply it into the adjustment term.
          azp_with_adj_i32 = azp_i32 * azp_adj_i32
          out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                          out_dtype, azp_with_adj_i32, None,
                                          func_bias)

      # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
      # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
      rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
      atol = 1e-3
      torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
      torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
  # Test working with a subset of A and B
  def test_cutlass_subset():
      """Run the int8 kernel on strided sub-views of larger A and B tensors."""
      big_m = big_n = big_k = 1024
      m = n = k = 512

      whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
      whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
      # Leading-corner slices keep the parent's strides, so the kernel sees
      # non-contiguous operands.
      a = whole_a[:m, :k]
      b = whole_b[:k, :n]

      scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
      scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

      common = dict(scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
      out = ops.cutlass_scaled_mm(a, b, **common)
      baseline = baseline_scaled_mm(a, b, **common)
      torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
  # Test to make sure cuda graphs work
  class CutlassLayer(torch.nn.Module):
      """Minimal module whose forward is a single ``ops.cutlass_scaled_mm``.

      B, both scales and the output dtype are fixed at construction; only A
      is supplied per call.
      """

      def __init__(self, b, scale_a, scale_b, out_dtype):
          super().__init__()
          self.b, self.scale_a, self.scale_b = b, scale_a, scale_b
          self.out_dtype = out_dtype

      def forward(self, a):
          """Return ``ops.cutlass_scaled_mm`` of ``a`` with the stored B."""
          return ops.cutlass_scaled_mm(a,
                                       self.b,
                                       self.scale_a,
                                       self.scale_b,
                                       self.out_dtype)
  @pytest.mark.parametrize("per_act_token", [True, False])
  @pytest.mark.parametrize("per_out_ch", [True, False])
  def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
      """Capture ``cutlass_scaled_mm`` in a CUDA graph and check its replay."""
      m = n = k = 512
      a = to_int8(torch.randn((m, k), device="cuda"))
      b = to_int8(torch.randn((n, k), device="cuda").t())

      scale_a = torch.randn((m if per_act_token else 1, 1),
                            device="cuda",
                            dtype=torch.float32) / 10
      scale_b = torch.randn((1, n if per_out_ch else 1),
                            device="cuda",
                            dtype=torch.float32) / 10

      # Trivial single-layer model whose forward is one CUTLASS call.
      model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

      # Capture the forward pass on a side stream.
      capture_stream = torch.cuda.Stream()
      with torch.cuda.stream(capture_stream):
          graph = torch.cuda.CUDAGraph()
          with torch.cuda.graph(graph):
              out = model(a)
      # Clear the captured output so the check below can only pass if the
      # replay actually rewrites it.
      out.zero_()
      graph.replay()

      a_ref = scale_a * a.to(dtype=torch.float32)
      b_ref = scale_b * b.to(dtype=torch.float32)
      baseline = torch.mm(a_ref, b_ref).to(torch.bfloat16)
      torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)