# test.py (30 KB)
  1. import torch
  2. import pytest
  3. from .utils import MetaData, get_input_shapes, input_helper, varlen_input_helper, DEBUG
  4. from .interface_torch import attention_prefill, attention_decode
  5. from .fwd_ref import attention_forward_pytorch_ref_impl, compute_alibi_tensor_ref
  6. from .fwd_prefill import attention_prefill_forward_triton_impl
  7. from .bwd_prefill import attention_prefill_backward_triton_impl
  8. from .bwd_ref import attention_backward_pytorch_ref_impl
  9. from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4
  10. # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
  11. ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose.
  12. # ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues
  13. # ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs
  14. # ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs
  15. EQUAL_NAN = True
  16. @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [
  17. (4, 48, 24, 1024, 1024, 64),
  18. (1, 24, 6, 8192, 8192, 64),
  19. (1, 4, 2, 16384, 16384, 128),
  20. (2, 16, 4, 1020, 987, 128),
  21. (2, 16, 4, 15498, 2, 128),
  22. (2, 16, 2, 7, 16219, 64),
  23. (4, 48, 12, 1, 1, 64),
  24. (4, 48, 48, 1, 1, 128),
  25. (4, 48, 24, 3, 3, 128),
  26. (4, 48, 48, 1001, 990, 64),
  27. (1, 8, 8, 8081, 7099, 64),
  28. (1, 4, 4, 16330, 15989, 128),
  29. (4, 4, 1, 1024, 1024, 33),
  30. (4, 4, 2, 65, 1018, 65),
  31. (4, 4, 4, 128, 128, 65),
  32. (4, 4, 4, 113, 123, 1),
  33. ])
  34. @pytest.mark.parametrize('causal', [True, False])
  35. @pytest.mark.parametrize('use_alibi', [True, False])
  36. @pytest.mark.parametrize('layout', ['bshd', 'bhsd'])
  37. def test_op_fwd_prefill(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16):
  38. torch.manual_seed(20)
  39. q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout)
  40. if causal:
  41. input_metadata.need_causal()
  42. if use_alibi:
  43. # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n)
  44. alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32,
  45. device="cuda").repeat(Z, 1)
  46. input_metadata.need_alibi(alibi_slopes, Z, HQ)
  47. else:
  48. alibi_slopes = None
  49. o = torch.empty_like(q)
  50. # triton implementation
  51. tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata)
  52. # Transpose here if layout is bshd so we have same reference code for all layouts
  53. if layout == 'bshd':
  54. q = q.transpose(1, 2).clone()
  55. k = k.transpose(1, 2).clone()
  56. v = v.transpose(1, 2).clone()
  57. # Replicate K and V if using MQA/GQA
  58. if HQ != HK:
  59. k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
  60. k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
  61. v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
  62. v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
  63. scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale
  64. if causal:
  65. mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q)
  66. scores[:, :, mask == 0] = float("-inf")
  67. if use_alibi:
  68. scores += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K)
  69. p = torch.softmax(scores, dim=-1)
  70. if causal:
  71. # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into
  72. # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix
  73. # this by converting the NaNs to 0s, which is what they should be out of the softmax.
  74. nan_mask = torch.isnan(p)
  75. p[nan_mask == 1] = 0
  76. ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
  77. # compare
  78. if layout == 'bshd':
  79. ref_out = ref_out.transpose(1, 2).clone()
  80. torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
  81. @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
  82. (4, 48, 1024, 1024, 64),
  83. (4, 12, 8192, 8192, 64),
  84. (2, 4, 16384, 16384, 128),
  85. (2, 16, 15498, 2, 128),
  86. (2, 4, 7, 16219, 64),
  87. (4, 48, 1, 1, 64),
  88. (4, 48, 1, 1, 128),
  89. (4, 48, 3, 3, 128),
  90. (4, 48, 1001, 990, 64),
  91. (1, 8, 8081, 7099, 64),
  92. (1, 8, 16330, 15989, 128),
  93. (4, 4, 1024, 1024, 33),
  94. (4, 4, 65, 1019, 65),
  95. (4, 4, 128, 128, 65),
  96. # TODO: This config fails. Disabled until triaged and fixed.
  97. # (2, 16, 1020, 987, 128),
  98. # (4, 4, 113, 123, 1),
  99. ])
  100. @pytest.mark.parametrize('causal', [True, False])
  101. @pytest.mark.parametrize('use_bias', [True])
  102. def test_op_fwd_prefill_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
  103. torch.manual_seed(20)
  104. sm_scale = D_HEAD**-0.5
  105. input_metadata = MetaData(sm_scale=sm_scale)
  106. q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd')
  107. if causal:
  108. input_metadata.need_causal()
  109. if use_bias:
  110. bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda")
  111. input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K)
  112. else:
  113. bias = None
  114. o = torch.empty_like(q)
  115. # triton implementation
  116. tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata)
  117. # reference implementation:171
  118. scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale
  119. if causal:
  120. mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q)
  121. scores[:, :, mask == 0] = float("-inf")
  122. if use_bias:
  123. scores += input_metadata.bias
  124. p = torch.softmax(scores, dim=-1)
  125. if causal:
  126. # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into
  127. # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix
  128. # this by converting the NaNs to 0s, which is what they should be out of the softmax.
  129. nan_mask = torch.isnan(p)
  130. p[nan_mask == 1] = 0
  131. ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
  132. # compare
  133. torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
  134. @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
  135. (4, 48, 8192, 64),
  136. (4, 48, 256, 64),
  137. (4, 48, 512, 64),
  138. (4, 48, 1024, 64),
  139. (8, 48, 4096, 64),
  140. (4, 48, 8192, 64),
  141. (4, 48, 128, 128),
  142. (4, 48, 4096, 128),
  143. (4, 48, 16384, 128),
  144. (4, 16, 1024, 128),
  145. (4, 16, 8192, 128),
  146. (32, 48, 8192, 128)
  147. ]
  148. )
  149. @pytest.mark.parametrize('causal', [True, False])
  150. def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
  151. q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype)
  152. tri_out = torch.empty_like(q)
  153. ref_out = torch.empty_like(q)
  154. for i in range(0, input_metadata.num_contexts):
  155. start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i]
  156. end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1]
  157. scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float()
  158. p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half()
  159. ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k])
  160. attention_prefill(q, k, v, tri_out, input_metadata)
  161. torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL)
  162. @pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64),
  163. (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64),
  164. (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128),
  165. (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128),
  166. (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)])
  167. @pytest.mark.parametrize('causal', [False])
  168. def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16):
  169. q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype)
  170. ref_out = torch.empty_like(q)
  171. tri_out = torch.empty_like(q)
  172. # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the
  173. # size aligns with Q.
  174. k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1)
  175. v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1)
  176. for i in range(0, input_metadata.num_contexts):
  177. start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i]
  178. end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1]
  179. k_curr = k_ref[start_k:end_k]
  180. k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3])
  181. v_curr = v_ref[start_k:end_k]
  182. v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3])
  183. scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float()
  184. p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half()
  185. ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr)
  186. attention_prefill(q, k, v, tri_out, input_metadata)
  187. torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL)
  188. @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
  189. # smallest config test
  190. (1, 1, 16, 16, 64), # pass on new # fail on old
  191. (1, 1, 32, 32, 64), # pass on new # fail on old
  192. (1, 1, 64, 64, 16), # pass # smallest head_size = 16
  193. (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64
  194. (1, 1, 128, 128, 64), # pass
  195. (1, 1, 256, 256, 64), # pass
  196. (1, 1, 512, 512, 64), # pass
  197. # failing FA
  198. (1, 1, 256, 512, 16),
  199. # old tests that work
  200. (4, 48, 1024, 1024, 64), # pass
  201. (4, 48, 2048, 2048, 64), # pass
  202. (2, 48, 4096, 4096, 64), # pass
  203. (1, 16, 1024, 1024, 64), # pass
  204. (1, 16, 1024, 1024, 128), # pass
  205. # old tests that were commented out
  206. # (1, 16, 8192, 8192, 63),
  207. # (1, 16, 1022, 1022, 64),
  208. ])
  209. # @pytest.mark.parametrize('torch_sdpa_test', [False, True])
  210. @pytest.mark.parametrize('torch_sdpa_test', [False])
  211. # @pytest.mark.parametrize('causal', [True, False])
  212. @pytest.mark.parametrize('causal', [False])
  213. # @pytest.mark.parametrize('use_alibi', [False, True])
  214. @pytest.mark.parametrize('use_alibi', [False])
  215. def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16):
  216. torch.manual_seed(20)
  217. DEBUG_INPUT = False
  218. # seqlens
  219. seqlen_q = N_CTX_Q
  220. seqlen_k = N_CTX_K
  221. # setup up metadata
  222. if DEBUG_INPUT:
  223. sm_scale = 1
  224. else:
  225. sm_scale = D_HEAD**-0.5
  226. input_metadata = MetaData(sm_scale=sm_scale)
  227. input_metadata.max_seqlens_q = seqlen_q
  228. input_metadata.max_seqlens_k = seqlen_k
  229. input_metadata.layout = "bhsd"
  230. dropout_p = 0
  231. if DEBUG_INPUT:
  232. q = torch.arange(seqlen_q, dtype=dtype, device="cuda").view(1, 1, seqlen_q, 1).expand(Z, H, seqlen_q, D_HEAD).requires_grad_()
  233. k = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_()
  234. v = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_()
  235. o = torch.zeros_like(q)
  236. else:
  237. # Generate random inputs
  238. q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
  239. k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
  240. v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
  241. o = torch.empty_like(q)
  242. if causal:
  243. input_metadata.need_causal()
  244. if use_alibi and not torch_sdpa_test:
  245. # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n)
  246. alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32,
  247. device="cuda").repeat(Z, 1)
  248. input_metadata.need_alibi(alibi_slopes, Z, H)
  249. if DEBUG_INPUT:
  250. dout = torch.ones_like(q)
  251. else:
  252. dout = torch.randn_like(q)
  253. # reference implementation
  254. if torch_sdpa_test:
  255. ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p,
  256. is_causal=causal, scale=sm_scale,
  257. dropout_mask=None)
  258. ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype))
  259. ref_dv, v.grad = v.grad.clone(), None
  260. ref_dk, k.grad = k.grad.clone(), None
  261. ref_dq, q.grad = q.grad.clone(), None
  262. else:
  263. M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda"))
  264. p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
  265. if use_alibi:
  266. p += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K)
  267. if causal:
  268. p[:, :, M == 0] = float("-inf")
  269. p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype)
  270. ref_out = torch.matmul(p, v)
  271. ref_out.backward(dout)
  272. ref_dv, v.grad = v.grad.clone(), None
  273. ref_dk, k.grad = k.grad.clone(), None
  274. ref_dq, q.grad = q.grad.clone(), None
  275. # # triton implementation
  276. tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata)
  277. tri_out.backward(dout)
  278. tri_dv, v.grad = v.grad.clone(), None
  279. tri_dk, k.grad = k.grad.clone(), None
  280. tri_dq, q.grad = q.grad.clone(), None
  281. # compare
  282. if DEBUG:
  283. print("tri_out:", tri_out)
  284. print("ref_out:",ref_out )
  285. torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
  286. # The current block size for MI200 series is 64x64. This results in
  287. # larger differences in float results due to rounding.
  288. if dtype == torch.bfloat16:
  289. ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
  290. if dtype == torch.float32:
  291. ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
  292. else:
  293. ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
  294. RTOL = 0
  295. if DEBUG:
  296. print("ref_dv:", ref_dv)
  297. print("tri_dv:", tri_dv)
  298. print("ref_dk:", ref_dk)
  299. print("tri_dk:", tri_dk)
  300. print("ref_dq:", ref_dq)
  301. print("tri_dq:", tri_dq)
  302. torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL)
  303. torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL)
  304. torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL)
  305. @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
  306. (1, 1, 1, 1, 1),
  307. (1, 1, 2, 4, 16),
  308. (1, 1, 4, 2, 16),
  309. (1, 1, 4, 4, 16),
  310. (1, 2, 4, 4, 16),
  311. (2, 1, 4, 4, 16),
  312. (2, 2, 4, 4, 16),
  313. (1, 1, 128, 64, 16),
  314. (2, 2, 2, 128, 1),
  315. (2, 3, 2, 128, 16),
  316. (3, 2, 256, 512, 16),
  317. (3, 3, 128, 128, 64),
  318. (2, 4, 1024, 1024, 64),
  319. (4, 6, 108, 256, 224),
  320. (4, 8, 2048, 2048, 128),
  321. (4, 16, 4096, 4096, 64),
  322. (2, 4, 8192, 8192, 32),
  323. # # fa configs
  324. (4, 6, 113, 203, 256),
  325. (4, 6, 128, 217, 256),
  326. (4, 6, 113, 211, 128),
  327. (4, 6, 108, 256, 128),
  328. (4, 6, 256, 512, 64),
  329. (4, 6, 512, 256, 64),
  330. (4, 6, 1024, 1024, 32),
  331. (4, 6, 1023, 1024, 32),
  332. (4, 6, 1024, 1023, 32),
  333. (4, 6, 2048, 2048, 32),
  334. ])
  335. @pytest.mark.parametrize('causal', [True, False])
  336. @pytest.mark.parametrize('return_scores', [False])
  337. @pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"])
  338. @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false
  339. @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues
  340. def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT):
  341. dtype = torch.float16
  342. torch.manual_seed(0)
  343. alibi_slopes = None
  344. dropout_p = 0.0
  345. device = "cuda"
  346. if layout == "thd":
  347. q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device, DEBUG_INPUT=DEBUG_INPUT)
  348. else:
  349. q, k, v, metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT)
  350. if DEBUG_INPUT:
  351. output_triton = torch.zeros_like(q).contiguous()
  352. else:
  353. output_triton = torch.empty_like(q)
  354. # update metadata
  355. metadata.use_exp2 = use_exp2
  356. if causal:
  357. metadata.need_causal()
  358. # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
  359. if return_scores:
  360. metadata.return_scores = True
  361. # call Triton's forward implementation directly
  362. ( output_triton,
  363. softmax_lse_triton,
  364. exp_scores_triton,
  365. _,
  366. _,
  367. _,
  368. _,
  369. _,
  370. _) = attention_prefill_forward_triton_impl(
  371. q,
  372. k,
  373. v,
  374. output_triton,
  375. metadata.sm_scale,
  376. metadata.alibi_slopes,
  377. metadata.causal,
  378. metadata.bias,
  379. metadata.dropout_p,
  380. metadata.layout,
  381. metadata.cu_seqlens_q,
  382. metadata.cu_seqlens_k,
  383. metadata.max_seqlens_q,
  384. metadata.max_seqlens_k,
  385. metadata.return_scores,
  386. metadata.use_exp2)
  387. (
  388. output_ref,
  389. softmax_lse_ref,
  390. exp_scores_ref,
  391. softmax_ref,
  392. attention_shifted_scaled_scores_ref,
  393. attention_scaled_scores_ref,
  394. attention_scores_ref,
  395. ) = attention_forward_pytorch_ref_impl(
  396. q.clone(),
  397. k.clone(),
  398. v.clone(),
  399. metadata.sm_scale,
  400. causal,
  401. layout,
  402. metadata.cu_seqlens_q,
  403. metadata.cu_seqlens_k,
  404. metadata.max_seqlens_q,
  405. metadata.max_seqlens_k,
  406. use_exp2
  407. )
  408. if DEBUG:
  409. print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape)
  410. print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
  411. torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL)
  412. if layout != "thd":
  413. # use trick with lse to get the softmax. you need the scores but is it
  414. softmax_triton = torch.exp(attention_scaled_scores_ref - softmax_lse_triton.unsqueeze(-1))
  415. if DEBUG:
  416. print("attention_scaled_scores_ref:", attention_scaled_scores_ref, attention_scaled_scores_ref.shape)
  417. print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape)
  418. print("softmax_triton:", softmax_triton, softmax_triton.shape)
  419. print("softmax_ref:", softmax_ref, softmax_ref.shape)
  420. torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL)
  421. if DEBUG:
  422. print("output_triton:", output_triton, output_triton.shape)
  423. print("output_ref:", output_ref, output_ref.shape)
  424. torch.testing.assert_close(output_triton, output_ref, atol=ATOL, rtol=RTOL)
  425. # compare with pytorch expect thd and causal impl is different
  426. if False and layout in ["bhsd", "bshd"] and not causal:
  427. out_pytorch, softmax_pytorch = torch.ops.aten._scaled_dot_product_attention_math(
  428. q.transpose(1, 2) if layout == "bshd" else q ,
  429. k.transpose(1, 2) if layout == "bshd" else k,
  430. v.transpose(1, 2) if layout == "bshd" else v,
  431. dropout_p=dropout_p,
  432. is_causal=causal, scale=metadata.sm_scale,
  433. dropout_mask=None)
  434. out_pytorch = out_pytorch.transpose(1, 2) if layout == "bshd" else out_pytorch
  435. if DEBUG:
  436. print("o:", output_triton, output_triton.shape)
  437. print("out_pytorch:", out_pytorch, out_pytorch.shape)
  438. torch.testing.assert_close(output_triton, out_pytorch, atol=ATOL, rtol=RTOL)
  439. # compare with pytorch output
  440. if DEBUG:
  441. print("softmax_triton:", softmax_triton, softmax_triton.shape)
  442. print("softmax_pytorch:", softmax_pytorch, softmax_pytorch.shape)
  443. torch.testing.assert_close(softmax_triton, softmax_pytorch.to(torch.float32), atol=ATOL, rtol=RTOL)
  444. @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
  445. (1, 1, 1, 1, 1),
  446. (1, 1, 4, 4, 4),
  447. (2, 1, 4, 4, 16),
  448. (1, 2, 4, 4, 16),
  449. (2, 2, 4, 4, 16),
  450. (1, 1, 4, 4, 16),
  451. (2, 1, 4, 4 , 16),
  452. (4, 6, 8, 8 , 16),
  453. (1, 1, 4, 4, 32),
  454. (1, 1, 16, 16, 16),
  455. (1, 1, 32, 32, 16),
  456. (1, 1, 64, 64, 16),
  457. (1, 1, 64, 64, 64),
  458. (1, 1, 64, 128, 32),
  459. (1, 1, 128, 128, 64),
  460. (1, 1, 128, 256, 45),
  461. (1, 1, 113, 203, 192),
  462. (1, 1, 256, 256, 64),
  463. (1, 1, 256, 512, 16),
  464. (1, 1, 512, 512, 64),
  465. (1, 1, 1024, 1024, 64),
  466. # fa configs
  467. (2, 2, 128, 128, 65),
  468. (2, 2, 128, 128, 224),
  469. (4, 6, 108, 256, 224),
  470. (1, 1, 256, 512, 16),
  471. # old tests that work
  472. (4, 48, 1024, 1024, 73),
  473. (4, 48, 1024, 1024, 64),
  474. (4, 48, 2048, 2048, 64),
  475. (1, 24, 4096, 4096, 64),
  476. (1, 16, 1024, 1024, 64),
  477. (1, 16, 1024, 1024, 128),
  478. ])
  479. @pytest.mark.parametrize('causal', [True, False])
  480. @pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal
  481. @pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"])
  482. @pytest.mark.parametrize('sequence_parallel', [True, False])
  483. @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend
  484. def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, sequence_parallel, DEBUG_INPUT):
  485. dtype = torch.float16
  486. torch.manual_seed(20) # seed from test_op_bwd
  487. alibi_slopes = None
  488. if layout == "thd":
  489. q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, DEBUG_INPUT=DEBUG_INPUT)
  490. else:
  491. q, k, v, metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, DEBUG_INPUT=DEBUG_INPUT)
  492. if DEBUG_INPUT:
  493. do = torch.ones_like(q).contiguous()
  494. else:
  495. do = torch.randn_like(q)
  496. # =============================================== Reference ==============================================================
  497. q_ref = q.clone()
  498. k_ref = k.clone()
  499. v_ref = v.clone()
  500. (
  501. o_ref,
  502. softmax_lse_ref,
  503. _,
  504. _,
  505. _,
  506. _,
  507. _,
  508. ) = attention_forward_pytorch_ref_impl(
  509. q_ref,
  510. k_ref,
  511. v_ref,
  512. metadata.sm_scale,
  513. causal,
  514. layout,
  515. metadata.cu_seqlens_q,
  516. metadata.cu_seqlens_k,
  517. metadata.max_seqlens_q,
  518. metadata.max_seqlens_k,
  519. use_exp2
  520. )
  521. dq = torch.zeros_like(q, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros
  522. if DEBUG_INPUT:
  523. dk = torch.zeros_like(k, dtype=k.dtype)
  524. dv = torch.zeros_like(v, dtype=v.dtype)
  525. else:
  526. dk = torch.empty_like(k, dtype=k.dtype)
  527. dv = torch.empty_like(v, dtype=v.dtype)
  528. do_ref = do.clone()
  529. dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
  530. do_ref,
  531. q_ref,
  532. k_ref,
  533. v_ref,
  534. o_ref,
  535. softmax_lse_ref,
  536. metadata.sm_scale,
  537. causal,
  538. layout,
  539. metadata.cu_seqlens_q,
  540. metadata.cu_seqlens_k,
  541. metadata.max_seqlens_q,
  542. metadata.max_seqlens_k,
  543. use_exp2
  544. )
  545. # =============================================== Triton ==============================================================
  546. o = o_ref.clone().contiguous()
  547. softmax_lse = softmax_lse_ref.clone().contiguous()
  548. dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(
  549. do,
  550. q,
  551. k,
  552. v,
  553. o,
  554. softmax_lse,
  555. dq,
  556. dk,
  557. dv,
  558. metadata.sm_scale,
  559. alibi_slopes,
  560. causal,
  561. layout,
  562. metadata.cu_seqlens_q,
  563. metadata.cu_seqlens_k,
  564. metadata.max_seqlens_q,
  565. metadata.max_seqlens_k,
  566. use_exp2,
  567. sequence_parallel=sequence_parallel
  568. )
  569. # =============================================== Check ==============================================================
  570. if DEBUG:
  571. print()
  572. if DEBUG:
  573. print("delta_triton:", delta_triton, delta_triton.shape)
  574. print("delta_ref:", delta_ref, delta_ref.shape)
  575. torch.testing.assert_close(delta_triton, delta_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN)
  576. if DEBUG:
  577. print("dv_triton:", dv_triton, dv_triton.shape)
  578. print("dv_ref:", dv_ref, dv_ref.shape)
  579. torch.testing.assert_close(dv_triton, dv_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN)
  580. if DEBUG:
  581. print("dk_triton:", dk_triton, dk_triton.shape)
  582. print("dk_ref:", dk_ref, dk_ref.shape)
  583. torch.testing.assert_close(dk_triton, dk_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN)
  584. if DEBUG:
  585. print("dq_triton:", dq_triton, dq_triton.shape)
  586. print("dq_ref:", dq_ref, dq_ref.shape)
  587. torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN)
  588. @pytest.mark.parametrize('batch_size, seqlen_q, seqlen_k, group_q, group_k, dim', get_input_shapes())
  589. def test_op_fwd_decode(batch_size, seqlen_q, seqlen_k, group_q, group_k, dim, dtype=torch.bfloat16):
  590. if DEBUG:
  591. print()
  592. print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, group_q = {group_q}, group_k = {group_k}, dim = {dim}")
  593. torch.manual_seed(20)
  594. query_group_head_size = (group_q + group_k - 1) // group_k
  595. q = (torch.empty((batch_size, seqlen_q, group_k, query_group_head_size, dim), dtype=dtype,
  596. device="cuda").normal_(mean=0., std=0.5).requires_grad_())
  597. k = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype,
  598. device="cuda").normal_(mean=0.,
  599. std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1)
  600. v = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype,
  601. device="cuda").normal_(mean=0.,
  602. std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1)
  603. scale = 1 / dim**0.5
  604. input_metadata = MetaData(sm_scale=scale)
  605. input_metadata.layout = "bsghd"
  606. tri_out, _ = attention_decode(q, k, v, input_metadata)
  607. q = q.reshape([batch_size, seqlen_q, -1, dim]).permute(0, 2, 1, 3)
  608. k = k.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3)
  609. v = v.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3)
  610. attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
  611. ref_out = attn @ v
  612. # compare
  613. torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0)
  614. def test_quantization():
  615. a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda')
  616. qa = quantize_kv_int4(a, num_groups=4)
  617. dqa = dequantize_kv_fp16(qa, num_groups=4)
  618. torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1)
  619. @pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes())
  620. def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
  621. pytest.skip("Decode kernel doesnot support quantization yet")
  622. torch.manual_seed(2)
  623. q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype,
  624. device="cuda").normal_(mean=1.0, std=0.5).requires_grad_())
  625. k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype,
  626. device="cuda").normal_(mean=1.0,
  627. std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1)
  628. v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype,
  629. device="cuda").normal_(mean=1.0,
  630. std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1)
  631. num_groups = 1
  632. quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32))
  633. quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32))
  634. scale = 1 / K**0.5
  635. input_metadata = MetaData(sm_scale=scale)
  636. input_metadata.layout = "bsghd"
  637. tri_out, _ = attention_decode(q, quant_k, quant_v, input_metadata)
  638. q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
  639. k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
  640. v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
  641. attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
  642. ref_out = attn @ v
  643. # compare
  644. torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0)
  645. # since quantization introduces rounding error, use the
  646. # dequantized kv as inputs to the ref implementation to reduce
  647. # the tolerance to 1e-3
  648. dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups)
  649. dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups)
  650. dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
  651. dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
  652. dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1)
  653. dq_ref_out = dq_attn @ dqv
  654. torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0)