test_cutlass.py

  1. """Tests for cutlass kernels
  2. Run `pytest tests/kernels/test_cutlass.py`.
  3. """
  4. from typing import Optional, Type
  5. import pytest
  6. import torch
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.platforms import current_platform
  9. from tests.kernels.utils import opcheck
  10. CUDA_DEVICES = [
  11. f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
  12. ]
  13. capability = current_platform.get_device_capability()
  14. capability = capability[0] * 10 + capability[1]
  15. def to_fp8(tensor: torch.Tensor):
  16. finfo = torch.finfo(torch.float8_e4m3fn)
  17. return torch.round(tensor.clamp(
  18. min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
  19. def to_int8(tensor: torch.Tensor):
  20. return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
  21. def rand_int8(shape: tuple, device: str = "cuda"):
  22. return to_int8(torch.rand(shape, device=device) * 255 - 128)
  23. def baseline_scaled_mm(a: torch.Tensor,
  24. b: torch.Tensor,
  25. scale_a: torch.Tensor,
  26. scale_b: torch.Tensor,
  27. out_dtype: Type[torch.dtype],
  28. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  29. output = (scale_a * (scale_b * (torch.mm(
  30. a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
  31. if bias is not None:
  32. output = output + bias
  33. return output
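

# Illustrative sketch (added for clarity, not part of the original suite): the
# baseline above relies on broadcasting a per-token scale of shape (M, 1)
# against a per-channel scale of shape (1, N) over the float32 matmul result.
# The helper below is a made-up, CPU-only demonstration of that broadcasting;
# it is not used by any test.
def _example_scale_broadcasting() -> torch.Tensor:
    a32 = torch.ones((2, 3), dtype=torch.float32)  # (M, K) activations
    b32 = torch.ones((3, 4), dtype=torch.float32)  # (K, N) weights
    scale_a = torch.tensor([[2.0], [3.0]])         # (M, 1) per-token scales
    scale_b = torch.full((1, 4), 0.5)              # (1, N) per-channel scales
    # Row i is scaled by scale_a[i] and column j by scale_b[j], matching
    # baseline_scaled_mm's `scale_a * (scale_b * mm(a, b))`.
    return scale_a * (scale_b * (a32 @ b32))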


def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            per_token_act_quant: bool,
                            per_out_channel_weight_quant: bool,
                            use_bias: bool,
                            out_dtype: Type[torch.dtype] = torch.bfloat16,
                            device: str = "cuda"):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))
    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)


def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
                             use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn((m_a_scales, 1), device=device,
                           dtype=torch.float32))
    scale_b = (torch.randn((1, n_b_scales), device=device,
                           dtype=torch.float32))

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

    opcheck(torch.ops._C.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))
  86. @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
  87. @pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
  88. @pytest.mark.parametrize("k", [128, 496, 1024])
  89. @pytest.mark.parametrize("per_act_token", [True, False])
  90. @pytest.mark.parametrize("per_out_ch", [True, False])
  91. @pytest.mark.parametrize("use_bias", [True, False])
  92. @pytest.mark.skipif(capability < 89,
  93. reason="FP8 is not supported on this GPU type.")
  94. def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
  95. per_out_ch: bool, use_bias: bool):
  96. cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
  97. @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
  98. @pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
  99. @pytest.mark.parametrize("k", [128, 496, 1024])
  100. @pytest.mark.parametrize("per_act_token", [True, False])
  101. @pytest.mark.parametrize("per_out_ch", [True, False])
  102. @pytest.mark.parametrize("use_bias", [True, False])
  103. def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
  104. per_out_ch: bool, use_bias: bool):
  105. cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
  106. @pytest.mark.parametrize("per_act_token", [True, False])
  107. @pytest.mark.parametrize("per_out_ch", [True, False])
  108. @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  109. @pytest.mark.parametrize("use_bias", [True, False])
  110. def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
  111. out_dtype: Type[torch.dtype],
  112. use_bias: bool):
  113. cutlass_int8_gemm_helper(512,
  114. 512,
  115. 512,
  116. per_act_token,
  117. per_out_ch,
  118. use_bias,
  119. out_dtype=out_dtype)
  120. @pytest.mark.parametrize("per_act_token", [True, False])
  121. @pytest.mark.parametrize("per_out_ch", [True, False])
  122. @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  123. @pytest.mark.parametrize("use_bias", [True, False])
  124. @pytest.mark.skipif(capability < 89,
  125. reason="FP8 is not supported on this GPU type.")
  126. def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
  127. out_dtype: Type[torch.dtype],
  128. use_bias: bool):
  129. cutlass_fp8_gemm_helper(512,
  130. 512,
  131. 512,
  132. per_act_token,
  133. per_out_ch,
  134. use_bias,
  135. out_dtype=out_dtype)
  136. @pytest.mark.parametrize("per_act_token", [True, False])
  137. @pytest.mark.parametrize("per_out_ch", [True, False])
  138. @pytest.mark.parametrize("use_bias", [True, False])
  139. @pytest.mark.parametrize("device", CUDA_DEVICES)
  140. @pytest.mark.skipif(capability < 89,
  141. reason="FP8 is not supported on this GPU type.")
  142. def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
  143. use_bias: bool, device: str):
  144. cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
  145. torch.bfloat16, device)
  146. @pytest.mark.parametrize("per_act_token", [True, False])
  147. @pytest.mark.parametrize("per_out_ch", [True, False])
  148. @pytest.mark.parametrize("use_bias", [True, False])
  149. @pytest.mark.parametrize("device", CUDA_DEVICES)
  150. def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
  151. use_bias: bool, device: str):
  152. cutlass_int8_gemm_helper(512,
  153. 512,
  154. 512,
  155. per_act_token,
  156. per_out_ch,
  157. use_bias,
  158. out_dtype=torch.bfloat16,
  159. device=device)


# For the following two tests:
# N and K correspond to the size of the weight matrix and are likely to be
# multiples of a large power of two. In any case, the kernel will have a naive
# fallback when N and K are not divisible by 16. But M is the number of tokens
# and the kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                    use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                     use_bias)
  185. @pytest.mark.parametrize("m", [32, 64, 128])
  186. @pytest.mark.parametrize("n", [16, 32, 64])
  187. @pytest.mark.parametrize("k", [64, 128, 256])
  188. @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
  189. @pytest.mark.skip
  190. def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
  191. out_dtype: torch.dtype):
  192. # Currently, the test is failing because folding azp into
  193. # 16-bit bias loses too much precision
  194. scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
  195. scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
  196. aq_i8 = rand_int8((m, k))
  197. bq_i8 = rand_int8((n, k)).t()
  198. aq_i32 = aq_i8.to(dtype=torch.int32)
  199. bq_i32 = bq_i8.to(dtype=torch.int32)
  200. aq_f32 = aq_i8.to(dtype=torch.float32)
  201. bq_f32 = bq_i8.to(dtype=torch.float32)
  202. b_dq = scale_b * bq_f32
  203. azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
  204. azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
  205. azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
  206. a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
  207. torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
  208. baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
  209. J = torch.ones((1, k), device="cuda", dtype=torch.float32)
  210. azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
  211. assert azp_bias.shape == (1, n)
  212. assert azp_bias[0, :].shape == (n, )
  213. baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
  214. (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
  215. dtype=out_dtype, device='cuda')
  216. out = ops.cutlass_scaled_mm(aq_i8,
  217. bq_i8,
  218. scale_a,
  219. scale_b,
  220. out_dtype=out_dtype,
  221. bias=azp_bias[0, :])
  222. torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
  223. torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
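

# Illustrative sketch (added for clarity, not part of the original suite): the
# bias-fold test above depends on the identity
#     (Aq + azp) @ Bq == Aq @ Bq + azp * (J @ Bq),   with J = ones((1, k)),
# which is why the azp contribution can be folded into a bias term. The helper
# below checks this numerically on CPU with made-up values; it is not used by
# any test.
def _example_azp_bias_fold_identity() -> None:
    torch.manual_seed(0)
    aq = torch.randint(-128, 128, (4, 8)).to(torch.float32)
    bq = torch.randint(-128, 128, (8, 3)).to(torch.float32)
    azp = 7.0                                    # per-tensor zero point
    j = torch.ones((1, 8), dtype=torch.float32)  # row of ones
    lhs = (aq + azp) @ bq
    rhs = aq @ bq + azp * (j @ bq)               # azp term folded into a bias row
    torch.testing.assert_close(lhs, rhs)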


@pytest.mark.parametrize("m", [32, 64, 128])
@pytest.mark.parametrize("n", [16, 32, 64])
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                          use_bias: bool, azp_per_token: bool):
    m_azp = m if azp_per_token else 1
    scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10

    aq_i8 = rand_int8((m, k))
    aq_i32 = aq_i8.to(dtype=torch.int32)
    aq_f32 = aq_i8.to(dtype=torch.float32)

    bq_i8 = rand_int8((n, k)).t()
    bq_i32 = bq_i8.to(dtype=torch.int32)
    bq_f32 = bq_i8.to(dtype=torch.float32)
    b_dq = scale_b * bq_f32

    azp_a = torch.rand(
        (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq,
                               scale_a * aq_f32 - azp_a,
                               rtol=1e-4,
                               atol=1e-3)

    if use_bias:
        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
    else:
        bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)

    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA
    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
    cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

    # Hadamard is just the sum of the cols
    azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
    azp_i32 = azp_aq_i8.to(dtype=torch.int32)
    func_bias = bias if use_bias else None

    if azp_per_token:
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_adj_i32, azp_i32,
                                        func_bias)
    else:
        azp_with_adj_i32 = azp_i32 * azp_adj_i32
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_with_adj_i32, None,
                                        func_bias)

    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
    rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
    atol = 1e-3
    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)

    if azp_per_token:
        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
                 func_bias))
    else:
        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
                 func_bias))
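

# Illustrative sketch (added for clarity, not part of the original suite): the
# azp test above corrects the integer GEMM with azp * colsum(Bq) (the row of
# ones dotted with Bq, per the "Hadamard is just the sum of the cols" comment).
# The helper below checks the underlying identity
#     (Aq - azp) @ Bq == Aq @ Bq - azp * Bq.sum(dim=0)
# for a per-token azp of shape (m, 1), using made-up CPU values; it is not used
# by any test.
def _example_azp_adjustment() -> None:
    torch.manual_seed(0)
    aq = torch.randint(-128, 128, (4, 8)).to(torch.float32)
    bq = torch.randint(-128, 128, (8, 3)).to(torch.float32)
    azp = torch.randint(1, 10, (4, 1)).to(torch.float32)  # per-token zero points
    azp_adj = bq.sum(dim=0, keepdim=True)                 # (1, n) column sums == J @ Bq
    lhs = (aq - azp) @ bq
    rhs = aq @ bq - azp * azp_adj
    torch.testing.assert_close(lhs, rhs)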


# Test working with a subset of A and B
def test_cutlass_subset():
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
    a = whole_a[0:m, 0:k]
    b = whole_b[0:k, 0:n]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a,
                                b,
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
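

# Illustrative note (added for clarity, not part of the original suite):
# slicing keeps the parent's strides, so the `a` and `b` used in
# test_cutlass_subset are non-contiguous views into the larger tensors, and the
# kernel must honor the leading-dimension stride rather than assume a densely
# packed layout. A made-up CPU-only check of that property:
def _example_subset_strides() -> None:
    whole = torch.zeros((1024, 1024), dtype=torch.float32)
    sub = whole[0:512, 0:512]
    assert sub.stride() == (1024, 1)  # row stride comes from the parent tensor
    assert not sub.is_contiguous()
    assert sub.data_ptr() == whole.data_ptr()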


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):

    def __init__(self, b, scale_a, scale_b, out_dtype):
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a):
        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                     self.out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    m, n, k = 512, 512, 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device="cuda", dtype=torch.float32) / 10)

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = model(a)

    out.zero_()
    g.replay()

    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)