test_flash_attn_ck.py

import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
from test_flash_attn import (
    attn_bias_from_alibi_slopes,
    convert_flash_attn_S_to_softmax,
    generate_qkv,
    generate_random_padding_mask,
    _generate_block_kvcache,
    attention_ref,
    attention_kvpacked_ref,
    attention_qkvpacked_ref,
)
from flash_attn.layers.rotary import apply_rotary_emb


def is_bwd_hdim_supported(d):
    return d <= 256
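
# The tests below gate their backward-pass checks on this helper, so gradients are only
# compared for head dims <= 256 (the head dims exercised by the forward tests as well).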


def ck_randval_to_dropout_mask(randval, p):
    # If p = 0.3, randval values in 255 * (0.7, 1.0] are dropped,
    # while randval values in 255 * [0, 0.7] are kept.
    # In the returned mask, entries >= 0 mean the element is kept.
    return math.floor(255.0 * (1 - p)) - randval.to(torch.float32)
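
# Worked example for ck_randval_to_dropout_mask (illustrative, not executed): with
# p = 0.17 the keep threshold is math.floor(255 * 0.83) == 211, so an input of
# torch.tensor([0., 100., 211., 212., 255.]) maps to tensor([211., 111., 0., -1., -44.]);
# the first three entries (>= 0) are kept and the last two are dropped.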


def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
    """Pad and rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded].
    Arguments:
        S_dmask: (nheads, total_q, max_seqlen_k)
        cu_seqlens_q: (b + 1)
    Output:
        S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)
    """
    batch_size = cu_seqlens_q.numel() - 1
    seqlens_q = torch.roll(cu_seqlens_q, shifts=-1) - cu_seqlens_q
    seqlens_q = seqlens_q[0:batch_size].tolist()
    S_dmask = torch.split(S_dmask, seqlens_q, dim=1)
    # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]
    masks = ()
    for mask in S_dmask:
        # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)
        mask = F.pad(
            mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)
        ).unsqueeze(1)
        masks = masks + (mask,)
    S_dmask = torch.cat(masks, dim=1)
    S_dmask = S_dmask.transpose(0, 1)
    return S_dmask
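
# Example for pad_rearrange_dropout_mask_hts_to_bhss (illustrative): with
# cu_seqlens_q = [0, 3, 8] (per-sequence lengths 3 and 5), each per-sequence slice of
# S_dmask is padded to (nheads, seqlen_q_rounded, seqlen_k_rounded) and the slices are
# stacked into a (2, nheads, seqlen_q_rounded, seqlen_k_rounded) tensor.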


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if d > 256:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        # TODO: move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        # CK does not return P, so we don't check the attention probabilities here.
    else:
        dropout_mask = None
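    # The recovered keep-mask is passed to the reference implementations below, so even
    # with dropout enabled the reference drops exactly the same elements as the kernel
    # and the outputs remain directly comparable.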
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
        # TODO: the tolerance is 10x for now; tighten once CK fixes the bwd precision issue
        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0, 0.17])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if d > 256:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )
    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # TODO: move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        # CK does not return P, so we don't check the attention probabilities here.
    else:
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
        # TODO: the tolerance is 10x for now; tighten once CK fixes the bwd precision issue
        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
        # TODO: move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        # CK does not return P, so we don't check the attention probabilities here.
    else:
        dropout_mask = None
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        if kvpacked:
            dq, dkv = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # TODO: the tolerance is 10x for now; tighten once CK fixes the bwd precision issue
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # TODO: move this conversion into the C++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        # CK does not return P, so we don't check the attention probabilities here.
    else:
        dropout_mask = None
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most 4 times the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        if kvpacked:
            dq_unpad, dkv_unpad = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # TODO: the tolerance is 10x for now; tighten once CK fixes the bwd precision issue
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        # (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if max(seqlen_q, seqlen_k) >= 2048:
        pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most 4 times the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 1e-5
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        do_o = (g.float() * out.float()).sum(-1)
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # TODO: the tolerance is 10x for now; tighten once CK fixes the bwd precision issue
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-4
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-4
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-4


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        # (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# TODO: Support paged_kv_block
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    if max(seqlen_q, seqlen_k) >= 2048:
        pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        do_o = (g.float() * out.float()).sum(-1)
        test_backward = block_table is None
        if test_backward:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        if test_backward:
            # TODO: the tolerance is 10x for now; tighten once CK fixes the bwd precision issue
            assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-5
            assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-5
            assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-5


# TODO - support splitkv
# def test_flash_attn_splitkv


# TODO - Support has_leftpad
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
@pytest.mark.parametrize("has_leftpad", [False])
@pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    has_leftpad,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
    dtype,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    if has_leftpad and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache,
            v_cache,
            block_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    if has_leftpad:
        cache_leftpad = torch.cat(
            [
                torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                if cache_seqlens[i].item() > 0
                else torch.zeros(1, dtype=torch.int32, device=device)
                for i in range(batch_size)
            ]
        )
    else:
        cache_leftpad = None
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
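    # Worked example (illustrative): with cache_seqlens = [5], seqlen_new = 2 and
    # new_kv=True, keys at positions 0..6 are valid (5 cached plus 2 appended), so
    # key_padding_mask is True for arange < 7 and False elsewhere.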
    if has_leftpad:
        key_padding_mask = torch.logical_and(
            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
        )
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes,
            seqlen_q,
            seqlen_k,
            None,
            key_padding_mask,
            causal=causal,
            key_leftpad=cache_leftpad,
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (
        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    v_cache_ref = (
        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        key_leftpad=cache_leftpad,
    )
    out_pt, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
        key_leftpad=cache_leftpad,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most `mult` times the numerical
    # error of a Pytorch implementation (see the assert at the end of this test).
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = (
                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            )
            v_cache_select = (
                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
            )
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    # mult = 3 would be enough for fp16, but bf16 needs 4
    mult = 4 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
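    # For reference, the call above mirrors a typical single-step decoding use of
    # flash_attn_with_kvcache (illustrative sketch, not executed; k_new/v_new are
    # hypothetical names for the freshly generated key/value tokens):
    #   out = flash_attn_with_kvcache(
    #       q, k_cache, v_cache, k_new, v_new, cache_seqlens=cache_seqlens, causal=True
    #   )
    # The new keys/values are written into k_cache/v_cache in place at the positions
    # given by cache_seqlens, which is what the cache checks above verify.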


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        # (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    """Re-run the same forward (and, when dropout_p == 0, backward) many times and check
    that the results do not change, to catch nondeterminism caused by race conditions.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need a large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if dropout_p == 0 and is_bwd_hdim_supported(d):
        dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()
    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)
        if dropout_p == 0:
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [16, 32, 64])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    # TODO: seqlen 1 or 2 might fail; needs checking
    if seqlen == 1 or seqlen == 2:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
        for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item() + 1e-3

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where the wrong strides of dout were used, which shows up
    when dout is not contiguous.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
    # Allocate twice the batch and keep every other element so that g is not contiguous
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = rearrange(out_pt, "b s ... -> s b ...")
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = rearrange(out_ref, "b s ... -> s b ...")
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item()
    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item()
    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item()
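
# Illustrative sketch (hypothetical helper): test_flash_attn_bwd_transpose builds its gradient
# by allocating twice the batch and keeping every other element, which yields a tensor of the
# right logical shape whose batch stride skips a row, i.e. a non-contiguous dout.
def _make_noncontiguous_grad(seqlen, batch_size, nheads, d, dtype=torch.float16, device="cuda"):
    return torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=device)[:, ::2]
# e.g. assert not _make_noncontiguous_grad(128, 5, 2, 64).is_contiguous()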

@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [16, 32, 64])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0 or with varlen (variable-length) sequences.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    Mq = 256
    Mk = 3
    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()
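
# Illustrative sketch (hypothetical helper): cumulative-sequence-length tensors describe a
# ragged batch packed along the token dimension. q_cuseqlen = [0, 76, 110, 256] above means
# three query sequences of lengths 76, 34 and 146; torch.diff recovers the per-sequence
# lengths from the cumulative offsets.
def _seqlens_from_cu(cu_seqlens):
    return torch.diff(cu_seqlens)
# e.g. _seqlens_from_cu(torch.tensor([0, 76, 110, 256])) -> tensor([76, 34, 146])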

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
    g = torch.randn_like(out)
    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)
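
# Illustrative sketch (hypothetical helper): with deterministic=True the backward pass must be
# bitwise reproducible, so repeated calls to torch.autograd.grad on the same graph are compared
# with torch.equal rather than torch.allclose.
def _grads_are_bitwise_stable(out, inputs, g, iters=50):
    ref = torch.autograd.grad(out, inputs, g, retain_graph=True)
    for _ in range(iters):
        new = torch.autograd.grad(out, inputs, g, retain_graph=True)
        if not all(torch.equal(a, b) for a, b in zip(new, ref)):
            return False
    return True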

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )
    g = torch.randn_like(out)
    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)