# test_flash_attn_ck.py
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
)

from test_flash_attn import (
    attn_bias_from_alibi_slopes,
    convert_flash_attn_S_to_softmax,
    generate_qkv,
    generate_random_padding_mask,
    attention_ref,
    attention_kvpacked_ref,
    attention_qkvpacked_ref,
)


def is_bwd_hdim_supported(d):
    return d <= 128 and d % 2 == 0


def ck_randval_to_dropout_mask(randval, p):
    # CK returns random values in [0, 255]. With dropout probability p (e.g. p = 0.3),
    # entries whose randval lies in 255 * (1 - p, 1.0] are dropped, and entries in
    # 255 * [0, 1 - p] are kept. After this transform, a value >= 0 means "kept".
    return torch.floor(255.0 * (1 - p) - randval)
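
# Quick sanity check of the mapping above (illustrative numbers, not taken from the kernel):
# with p = 0.17 the keep threshold is 255 * 0.83 = 211.65, so randval = 100 maps to
# floor(211.65 - 100) = 111 (>= 0, kept) while randval = 230 maps to floor(211.65 - 230) = -19 (dropped).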


def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
    """Pad and rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded].

    Arguments:
        S_dmask: (nheads, total_q, max_seqlen_k)
        cu_seqlens_q: (b + 1)
    Output:
        S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)
    """
    batch_size = cu_seqlens_q.numel() - 1
    seqlens_q = torch.roll(cu_seqlens_q, shifts=-1) - cu_seqlens_q
    seqlens_q = seqlens_q[0:batch_size].tolist()
    S_dmask = torch.split(S_dmask, seqlens_q, dim=1)
    # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]
    masks = ()
    for mask in S_dmask:
        # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)
        mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1)
        masks = masks + (mask,)
    S_dmask = torch.cat(masks, dim=1)
    S_dmask = S_dmask.transpose(0, 1)
    return S_dmask
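
# Shape walk-through for the helper above (hypothetical sizes, for illustration only):
# with cu_seqlens_q = [0, 3, 5], nheads = 2, max_seqlen_k = 7 and rounded sizes (8, 8),
# the (2, 5, 7) input is split along dim=1 into chunks of shape (2, 3, 7) and (2, 2, 7),
# each chunk is zero-padded and unsqueezed to (2, 1, 8, 8), the chunks are concatenated
# along dim=1 to (2, 2, 8, 8), and the result is transposed to (batch=2, nheads=2, 8, 8).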


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if d > 256:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        # TODO - move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        # CK does not return P, so we do not check the attention probabilities here.
    else:
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
        # TODO - using a 10x tolerance for now; tighten it once CK changes the dq type to f32.
        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0, 0.17])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if d > 256:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )
    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # TODO - move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        # CK does not return P, so we do not check the attention probabilities here.
    else:
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
        # TODO - using a 10x tolerance for now; tighten it once CK changes the dq type to f32.
        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
        # TODO - move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        # CK does not return P, so we do not check the attention probabilities here.
    else:
        dropout_mask = None
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        if kvpacked:
            (
                dq,
                dkv,
            ) = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # TODO - using a 10x tolerance for now; tighten it once CK changes the dq type to f32.
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # TODO - move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        # CK does not return P, so we do not check the attention probabilities here.
    else:
        dropout_mask = None
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most 4 times the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        if kvpacked:
            (
                dq_unpad,
                dkv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq_unpad,
                dk_unpad,
                dv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # TODO - using a 10x tolerance for now; tighten it once CK changes the dq type to f32.
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()