test_flash_attn.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031
  1. import os
  2. import math
  3. import itertools
  4. import pytest
  5. import torch
  6. import torch.nn.functional as F
  7. from einops import rearrange, repeat
  8. from flash_attn.layers.rotary import apply_rotary_emb
  9. from padding import pad_input, unpad_input
  10. from test_util import (
  11. attention_ref,
  12. generate_qkv,
  13. generate_random_padding_mask,
  14. )
  15. from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine, flash_attn_with_kvcache
  16. DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
  17. DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
  18. DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
  19. DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
  20. DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
  21. DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
  22. DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
  23. DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
  24. DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
  25. DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE"
  26. DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE"
  27. DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE"
  28. DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE"
  29. DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE"
  30. COMPILED_HDIMS = (
  31. []
  32. + ([64] if not DISABLE_HDIM64 else [])
  33. + ([96] if not DISABLE_HDIM96 else [])
  34. + ([128] if not DISABLE_HDIM128 else [])
  35. + ([192] if not DISABLE_HDIM192 else [])
  36. + ([256] if not DISABLE_HDIM256 else [])
  37. )
  38. # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
  39. @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
  40. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  41. # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  42. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  43. # @pytest.mark.parametrize("mha_type", ["mha"])
  44. # @pytest.mark.parametrize("deterministic", [False, True])
  45. @pytest.mark.parametrize("deterministic", [False])
  46. @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
  47. # @pytest.mark.parametrize("softcap", [0.0])
  48. @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
  49. # @pytest.mark.parametrize("local", [False])
  50. @pytest.mark.parametrize("causal", [False, True])
  51. # @pytest.mark.parametrize("causal", [False])
  52. # @pytest.mark.parametrize("V_colmajor", [False, True])
  53. @pytest.mark.parametrize("V_colmajor", [False])
  54. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  55. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
  56. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  57. # @pytest.mark.parametrize('d', [56, 80])
  58. # @pytest.mark.parametrize("d", [64, 128, 256])
  59. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
  60. # @pytest.mark.parametrize("d", [64, 96, 128, 192])
  61. @pytest.mark.parametrize("d", COMPILED_HDIMS)
  62. # @pytest.mark.parametrize("d", [128])
  63. @pytest.mark.parametrize(
  64. "seqlen_q,seqlen_k",
  65. [
  66. (1, 1),
  67. (64, 128),
  68. (128, 192),
  69. (256, 256),
  70. (239, 1),
  71. (799, 3),
  72. (113, 203),
  73. (113, 128),
  74. (128, 217),
  75. (113, 211),
  76. (108, 256),
  77. (256, 512),
  78. (384, 256),
  79. (640, 128),
  80. (512, 256),
  81. (1024, 1024),
  82. (1023, 1024),
  83. (1024, 1023),
  84. (4096, 4096),
  85. (4224, 4224),
  86. ],
  87. )
  88. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
  89. def test_flash_attn_output(
  90. seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, mha_type, dtype
  91. ):
  92. # sink_token_length = 0 if not local else 4
  93. sink_token_length = 0 if not local else 0
  94. if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
  95. pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
  96. device = "cuda"
  97. # set seed
  98. torch.random.manual_seed(0)
  99. # batch_size = 40
  100. # nheads = 16
  101. batch_size = 9 if seqlen_k <= 2048 else 2
  102. # batch_size = 1
  103. nheads = 6
  104. # nheads = 1
  105. nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  106. dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
  107. q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
  108. if softcap > 0.0:
  109. # Ensure the values of qk are at least within softcap range.
  110. q_ref = (q_ref * softcap / 4)
  111. q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
  112. k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  113. v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  114. # Put window_size after QKV randn so that window_size changes from test to test
  115. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  116. # window_size = (-1, -1) if not local else (16, 0)
  117. if dtype == torch.float8_e4m3fn:
  118. q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
  119. else:
  120. q_descale, k_descale, v_descale = None, None, None
  121. q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
  122. if V_colmajor:
  123. v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_()
  124. out_ref, attn_ref = attention_ref(
  125. q_ref,
  126. k_ref,
  127. v_ref,
  128. None,
  129. None,
  130. causal=causal,
  131. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  132. window_size=window_size,
  133. sink_token_length=sink_token_length,
  134. softcap=softcap
  135. )
  136. out_pt, attn_pt = attention_ref(
  137. q_ref,
  138. k_ref,
  139. v_ref,
  140. None,
  141. None,
  142. causal=causal,
  143. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  144. window_size=window_size,
  145. sink_token_length=sink_token_length,
  146. softcap=softcap,
  147. upcast=False,
  148. reorder_ops=True,
  149. intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
  150. )
  151. # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()
  152. # m = qk.amax(-1, keepdim=True)
  153. # s_tmp = torch.exp((qk - m) / math.sqrt(d))
  154. # exp_sum = s_tmp.sum(-1)
  155. # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
  156. # lse_ref = torch.logsumexp(qk, dim=-1)
  157. abs_tol = 1e-4 if softcap == 0.0 else 5e-4
  158. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  159. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  160. pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
  161. num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
  162. for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
  163. out, lse = flash_attn_func(
  164. q,
  165. k,
  166. v,
  167. causal=causal,
  168. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  169. window_size=window_size,
  170. sink_token_length=sink_token_length,
  171. softcap=softcap,
  172. pack_gqa=pack_gqa,
  173. num_splits=num_splits
  174. )
  175. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  176. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  177. # if not causal:
  178. # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  179. # breakpoint()
  180. # Check that FlashAttention's numerical error is at most twice the numerical error
  181. # of a Pytorch implementation.
  182. multiple = 2 if dtype != torch.float8_e4m3fn else 3
  183. assert (out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item() + abs_tol
  184. if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor:
  185. g = torch.randn_like(out)
  186. do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
  187. # import flash_attn_3_cuda
  188. # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd(
  189. # g,
  190. # q,
  191. # k,
  192. # v,
  193. # out,
  194. # lse,
  195. # None,
  196. # None,
  197. # None,
  198. # d ** (-0.5),
  199. # causal,
  200. # window_size[0], window_size[1],
  201. # sink_token_length,
  202. # softcap,
  203. # deterministic,
  204. # 0, # sm_margin
  205. # )
  206. dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  207. # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
  208. # assert (softmax_d - do_o).abs().max().item() <= 1e-5
  209. # assert dq_accum.abs().max().item() == 0.0
  210. # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
  211. # P = torch.softmax(qk, -1)
  212. # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))
  213. # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
  214. # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
  215. # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
  216. # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  217. dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
  218. dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
  219. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  220. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  221. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  222. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  223. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  224. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  225. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  226. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  227. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  228. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  229. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  230. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  231. # breakpoint()
  232. if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor:
  233. multiple = 2
  234. assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item() + abs_tol
  235. assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item() + abs_tol
  236. assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item() + abs_tol
  237. # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
  238. @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
  239. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  240. # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  241. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  242. # @pytest.mark.parametrize("mha_type", ["mha"])
  243. # @pytest.mark.parametrize("deterministic", [False, True])
  244. @pytest.mark.parametrize("deterministic", [False])
  245. @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
  246. # @pytest.mark.parametrize("softcap", [0.0])
  247. @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
  248. # @pytest.mark.parametrize("local", [False])
  249. @pytest.mark.parametrize("causal", [False, True])
  250. # @pytest.mark.parametrize("causal", [False])
  251. @pytest.mark.parametrize("add_unused_qkv", [False, True])
  252. # @pytest.mark.parametrize("add_unused_qkv", [True])
  253. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  254. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
  255. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  256. # @pytest.mark.parametrize('d', [56, 80])
  257. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
  258. # @pytest.mark.parametrize("d", [64, 96, 128])
  259. @pytest.mark.parametrize("d", COMPILED_HDIMS)
  260. # @pytest.mark.parametrize("d", [128])
  261. @pytest.mark.parametrize(
  262. "seqlen_q,seqlen_k",
  263. [
  264. (1, 1),
  265. (1, 3),
  266. (2, 1),
  267. (511, 1),
  268. (3, 513),
  269. (64, 128),
  270. (128, 128),
  271. (256, 256),
  272. (113, 203),
  273. (128, 217),
  274. (113, 211),
  275. (108, 256),
  276. (256, 512),
  277. (307, 256),
  278. (640, 128),
  279. (512, 256),
  280. (1024, 1024),
  281. (1023, 1024),
  282. (1024, 1023),
  283. (2048, 2048),
  284. ],
  285. )
  286. def test_flash_attn_varlen_output(
  287. seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, mha_type, dtype
  288. ):
  289. device = "cuda"
  290. # set seed
  291. torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
  292. # batch_size = 40
  293. # nheads = 16
  294. batch_size = 9 if seqlen_q <= 2048 else 2
  295. nheads = 6
  296. # batch_size = 2
  297. # nheads = 1
  298. nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  299. dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
  300. q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
  301. if softcap > 0.0:
  302. # Ensure the values of qk are at least within softcap range.
  303. q_ref = (q_ref * softcap / 4).detach().requires_grad_()
  304. q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
  305. k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  306. v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  307. # Put window_size after QKV randn so that window_size changes from test to test
  308. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  309. if dtype == torch.float8_e4m3fn:
  310. q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
  311. else:
  312. q_descale, k_descale, v_descale = None, None, None
  313. q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
  314. query_padding_mask = generate_random_padding_mask(
  315. seqlen_q, batch_size, device, mode="random", zero_lengths=False
  316. )
  317. key_padding_mask = generate_random_padding_mask(
  318. seqlen_k, batch_size, device, mode="random", zero_lengths=True
  319. )
  320. def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
  321. if add_unused:
  322. another_mask = generate_random_padding_mask(max_seq_len, bs, device)
  323. attn_mask = torch.logical_and(padding_mask, another_mask)
  324. unused_mask = torch.logical_xor(
  325. torch.logical_or(padding_mask, another_mask), attn_mask
  326. )
  327. else:
  328. attn_mask = padding_mask
  329. unused_mask = None
  330. return attn_mask, unused_mask
  331. query_padding_mask, query_unused_mask = _gen_unused_masks(
  332. query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
  333. )
  334. key_padding_mask, key_unused_mask = _gen_unused_masks(
  335. key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
  336. )
  337. (
  338. q_unpad,
  339. k_unpad,
  340. v_unpad,
  341. cu_seqlens_q,
  342. cu_seqlens_k,
  343. seqused_q,
  344. seqused_k,
  345. max_seqlen_q,
  346. max_seqlen_k,
  347. q,
  348. k,
  349. v,
  350. output_pad_fn,
  351. dq_pad_fn,
  352. dk_pad_fn,
  353. ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False,
  354. query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
  355. q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]
  356. out_ref, attn_ref = attention_ref(
  357. q_ref,
  358. k_ref,
  359. v_ref,
  360. query_padding_mask,
  361. key_padding_mask,
  362. causal=causal,
  363. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  364. window_size=window_size,
  365. softcap=softcap
  366. )
  367. out_pt, attn_pt = attention_ref(
  368. q_ref,
  369. k_ref,
  370. v_ref,
  371. query_padding_mask,
  372. key_padding_mask,
  373. causal=causal,
  374. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  375. window_size=window_size,
  376. softcap=softcap,
  377. upcast=False,
  378. reorder_ops=True,
  379. intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
  380. )
  381. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  382. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  383. if query_unused_mask is not None:
  384. q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
  385. # Numerical error if we just do any arithmetic on out_ref
  386. fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
  387. rel_tol = 2 if softcap == 0.0 else 3
  388. pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
  389. num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
  390. for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
  391. out_unpad, lse = flash_attn_varlen_func(
  392. q_unpad,
  393. k_unpad,
  394. v_unpad,
  395. cu_seqlens_q,
  396. cu_seqlens_k,
  397. seqused_q, seqused_k,
  398. max_seqlen_q,
  399. max_seqlen_k,
  400. causal=causal,
  401. q_descale=q_descale,
  402. k_descale=k_descale, v_descale=v_descale,
  403. window_size=window_size,
  404. softcap=softcap,
  405. )
  406. out = output_pad_fn(out_unpad)
  407. if query_unused_mask is not None:
  408. out.masked_fill_(q_zero_masking, 0.0)
  409. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  410. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  411. # if not causal:
  412. # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  413. # breakpoint()
  414. # Check that FlashAttention's numerical error is at most 3x the numerical error
  415. # of a Pytorch implementation.
  416. assert (out - out_ref).abs().max().item() <= rel_tol * (out_pt - out_ref).abs().max().item() + fwd_atol
  417. if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn:
  418. g_unpad = torch.randn_like(out_unpad)
  419. do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
  420. # import flash_attn_3_cuda
  421. # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
  422. # g_unpad,
  423. # q_unpad,
  424. # k_unpad,
  425. # v_unpad,
  426. # out_unpad,
  427. # lse,
  428. # None,
  429. # None,
  430. # None,
  431. # cu_seqlens_q,
  432. # cu_seqlens_k,
  433. # None, None,
  434. # max_seqlen_q,
  435. # max_seqlen_k,
  436. # d ** (-0.5),
  437. # causal,
  438. # window_size[0], window_size[1],
  439. # softcap,
  440. # deterministic,
  441. # 0, # sm_margin
  442. # )
  443. dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad)
  444. dq = dq_pad_fn(dq_unpad)
  445. dk = dk_pad_fn(dk_unpad)
  446. dv = dk_pad_fn(dv_unpad)
  447. if key_unused_mask is not None:
  448. k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
  449. dk.masked_fill_(k_zero_masking, 0.0)
  450. dv.masked_fill_(k_zero_masking, 0.0)
  451. if query_unused_mask is not None:
  452. dq.masked_fill_(q_zero_masking, 0.0)
  453. # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
  454. # assert (softmax_d - do_o).abs().max().item() <= 1e-5
  455. # assert dq_accum.abs().max().item() == 0.0
  456. g = output_pad_fn(g_unpad)
  457. # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
  458. # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
  459. # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
  460. # P = torch.softmax(qk, -1)
  461. # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
  462. # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
  463. # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
  464. # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
  465. # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  466. dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
  467. dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
  468. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  469. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  470. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  471. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  472. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  473. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  474. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  475. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  476. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  477. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  478. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  479. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  480. # breakpoint()
  481. if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn:
  482. dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
  483. assert (dq - dq_ref).abs().max().item() <= rel_tol * (dq_pt - dq_ref).abs().max().item() + dq_atol
  484. dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
  485. assert (dk - dk_ref).abs().max().item() <= rel_tol * (dk_pt - dk_ref).abs().max().item() + dk_atol
  486. dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
  487. assert (dv - dv_ref).abs().max().item() <= rel_tol * (dv_pt - dv_ref).abs().max().item() + dv_atol
  488. # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
  489. @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
  490. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  491. # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  492. @pytest.mark.parametrize("num_splits", [1] + ([0] if not DISABLE_SPLIT else []))
  493. # @pytest.mark.parametrize("num_splits", [1])
  494. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  495. # @pytest.mark.parametrize("mha_type", ["mha"])
  496. @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
  497. # @pytest.mark.parametrize("new_kv", [True])
  498. # @pytest.mark.parametrize("local", [False, True])
  499. @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []))
  500. # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
  501. # @pytest.mark.parametrize("causal,local", [(False, False)])
  502. @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True])
  503. # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
  504. @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False])
  505. # @pytest.mark.parametrize("rotary_interleaved", [True])
  506. @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if not DISABLE_APPENDKV else [0.0])
  507. # @pytest.mark.parametrize("rotary_fraction", [0.0])
  508. @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []))
  509. # @pytest.mark.parametrize("page_size", [None])
  510. @pytest.mark.parametrize("has_leftpad", [False, True])
  511. # @pytest.mark.parametrize("has_leftpad", [False])
  512. @pytest.mark.parametrize("has_batch_idx", [False, True])
  513. # @pytest.mark.parametrize("has_batch_idx", [False])
  514. @pytest.mark.parametrize("varlen_q", [False, True])
  515. # @pytest.mark.parametrize("varlen_q", [True])
  516. # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
  517. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  518. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
  519. # @pytest.mark.parametrize('d', [56, 80])
  520. @pytest.mark.parametrize("d", [128])
  521. @pytest.mark.parametrize(
  522. "seqlen_q,seqlen_k",
  523. [
  524. (1, 128),
  525. (1, 339),
  526. (3, 1024),
  527. (64, 800),
  528. (64, 256),
  529. (3, 799),
  530. (64, 2048),
  531. (16, 20000),
  532. (1, 128 * 1024),
  533. (16, 128 * 1024),
  534. (128, 128),
  535. (256, 512), # To test appending KV with more than 1 block
  536. (2048, 3577), # Enough tile to test persistent scheduler
  537. ],
  538. )
  539. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
  540. def test_flash_attn_kvcache(
  541. seqlen_q,
  542. seqlen_k,
  543. d,
  544. varlen_q,
  545. has_batch_idx,
  546. has_leftpad,
  547. page_size,
  548. rotary_fraction,
  549. rotary_interleaved,
  550. seqlen_new_eq_seqlen_q,
  551. causal,
  552. local,
  553. new_kv,
  554. mha_type,
  555. num_splits,
  556. dtype,
  557. ):
  558. if page_size is not None and seqlen_k % page_size != 0:
  559. pytest.skip()
  560. if seqlen_q > seqlen_k and new_kv:
  561. pytest.skip()
  562. if not new_kv and rotary_fraction > 0.0:
  563. pytest.skip()
  564. device = "cuda"
  565. # set seed
  566. torch.random.manual_seed(0)
  567. batch_size = 5
  568. # batch_size = 1
  569. batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
  570. nheads = 6
  571. # nheads = 1
  572. # rotary_dim must be a multiple of 16, and must be <= d
  573. rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
  574. nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
  575. assert nheads % nheads_k == 0
  576. dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
  577. q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
  578. if varlen_q:
  579. query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
  580. q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)
  581. output_pad_fn = lambda output_unpad: pad_input(
  582. output_unpad, indices_q, batch_size, seqlen_q
  583. )
  584. else:
  585. query_padding_mask = None
  586. q_unpad = q
  587. cu_seqlens_q, max_seqlen_q = None, None
  588. # Put window_size after QKV randn so that window_size changes from test to test
  589. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  590. seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
  591. cu_seqlens_k_new = None
  592. key_new_padding_mask = None
  593. if new_kv:
  594. k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
  595. v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
  596. if varlen_q: # k & v are also varlen
  597. key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random")
  598. k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask)
  599. v_unpad, *rest = unpad_input(v, key_new_padding_mask)
  600. else:
  601. k_unpad, v_unpad = k, v
  602. else:
  603. k, v, k_unpad, v_unpad = None, None, None, None
  604. if page_size is None:
  605. k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
  606. v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
  607. page_table = None
  608. else:
  609. (
  610. k_cache,
  611. v_cache,
  612. page_table,
  613. k_cache_paged,
  614. v_cache_paged,
  615. num_blocks,
  616. ) = _generate_block_kvcache(
  617. seqlen_k, page_size, batch_size_cache, nheads_k, d, device, dtype_ref
  618. )
  619. cache_seqlens = torch.randint(
  620. 0 if new_kv else 1,
  621. # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
  622. (
  623. (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
  624. if new_kv
  625. else (seqlen_k + 1)
  626. ),
  627. (batch_size,),
  628. dtype=torch.int32,
  629. device=device,
  630. )
  631. if has_leftpad:
  632. cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
  633. if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
  634. for i in range(batch_size)])
  635. else:
  636. cache_leftpad = None
  637. if has_batch_idx:
  638. cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
  639. :batch_size
  640. ]
  641. else:
  642. cache_batch_idx = None
  643. arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
  644. cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
  645. if not new_kv:
  646. key_padding_mask = arange < cache_seqlens_expanded
  647. else:
  648. k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
  649. key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
  650. if has_leftpad:
  651. key_padding_mask = torch.logical_and(
  652. key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
  653. )
  654. # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
  655. if rotary_dim > 0:
  656. angle = (
  657. torch.rand(
  658. seqlen_k if page_size is None else num_blocks * page_size,
  659. rotary_dim // 2,
  660. device=device,
  661. )
  662. * 2
  663. * math.pi
  664. )
  665. cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
  666. sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
  667. if causal or local:
  668. q_ro = apply_rotary_emb(
  669. q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
  670. )
  671. else:
  672. q_ro = rearrange(
  673. apply_rotary_emb(
  674. rearrange(q, "b s h d -> b 1 (s h) d"),
  675. cos,
  676. sin,
  677. seqlen_offsets=cache_seqlens,
  678. interleaved=rotary_interleaved,
  679. ),
  680. "b 1 (s h) d -> b s h d",
  681. s=seqlen_q,
  682. )
  683. # q_ro = q
  684. k_ro = apply_rotary_emb(
  685. k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
  686. )
  687. else:
  688. cos, sin = None, None
  689. q_ro, k_ro = q, k
  690. # k_cache[:, 64:] = -1
  691. k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
  692. v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
  693. if new_kv:
  694. update_mask = torch.logical_and(
  695. cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens
  696. )
  697. k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
  698. v_to_update = rearrange(v, "b s ... -> (b s) ...")
  699. if varlen_q:
  700. k_to_update = k_to_update[indices_k]
  701. v_to_update = v_to_update[indices_k]
  702. k_cache_ref[update_mask] = k_to_update
  703. v_cache_ref[update_mask] = v_to_update
  704. k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
  705. v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
  706. out_ref, _ = attention_ref(
  707. q_ro,
  708. k_cache_rep,
  709. v_cache_rep,
  710. query_padding_mask,
  711. key_padding_mask,
  712. causal=causal,
  713. window_size=window_size,
  714. key_leftpad=cache_leftpad,
  715. )
  716. out_pt, _ = attention_ref(
  717. q_ro,
  718. k_cache_rep,
  719. v_cache_rep,
  720. query_padding_mask,
  721. key_padding_mask,
  722. causal=causal,
  723. window_size=window_size,
  724. upcast=False,
  725. reorder_ops=True,
  726. key_leftpad=cache_leftpad,
  727. intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None
  728. )
  729. q = q.to(dtype)
  730. q_unpad = q_unpad.to(dtype) if varlen_q else None
  731. k_cache = k_cache.to(dtype)
  732. v_cache = v_cache.to(dtype)
  733. k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
  734. v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
  735. k = k.to(dtype) if k is not None else None
  736. v = v.to(dtype) if v is not None else None
  737. k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
  738. v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
  739. cos = cos.to(dtype) if cos is not None else None
  740. sin = sin.to(dtype) if sin is not None else None
  741. out, lse, *rest = flash_attn_with_kvcache(
  742. q if not varlen_q else q_unpad,
  743. k_cache if page_size is None else k_cache_paged,
  744. v_cache if page_size is None else v_cache_paged,
  745. k if not new_kv or not varlen_q else k_unpad,
  746. v if not new_kv or not varlen_q else v_unpad,
  747. rotary_cos=cos,
  748. rotary_sin=sin,
  749. cache_seqlens=cache_seqlens,
  750. cache_batch_idx=cache_batch_idx,
  751. cache_leftpad=cache_leftpad,
  752. page_table=page_table,
  753. cu_seqlens_q=cu_seqlens_q,
  754. cu_seqlens_k_new=cu_seqlens_k_new,
  755. max_seqlen_q=max_seqlen_q,
  756. causal=causal,
  757. window_size=window_size,
  758. rotary_interleaved=rotary_interleaved,
  759. num_splits=num_splits,
  760. return_softmax_lse=True
  761. )
  762. if varlen_q:
  763. out = output_pad_fn(out)
  764. # out = flash_attn_with_kvcache(
  765. # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
  766. # )
  767. # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
  768. # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
  769. # m = qk.amax(-1, keepdim=True)
  770. # s_tmp = torch.exp((qk - m) / math.sqrt(d))
  771. # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
  772. # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
  773. # probs = torch.softmax(qk, dim=-1)
  774. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  775. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  776. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  777. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  778. # breakpoint()
  779. # Check that FlashAttention's numerical error is at most twice the numerical error
  780. # of a Pytorch implementation.
  781. if new_kv:
  782. if page_size is None:
  783. k_cache_select = (
  784. k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx]
  785. )
  786. v_cache_select = (
  787. v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx]
  788. )
  789. else:
  790. k_cache_select = rearrange(
  791. k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],
  792. "(b nblocks) block_size ... -> b (nblocks block_size) ...",
  793. b=batch_size,
  794. )[:, :seqlen_k].to(dtype_ref)
  795. v_cache_select = rearrange(
  796. v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],
  797. "(b nblocks) block_size ... -> b (nblocks block_size) ...",
  798. b=batch_size,
  799. )[:, :seqlen_k].to(dtype_ref)
  800. k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
  801. v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
  802. if dtype is not torch.float8_e4m3fn:
  803. assert torch.equal(v_cache_select, v_cache_ref)
  804. else:
  805. assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3)
  806. # breakpoint()
  807. # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
  808. if rotary_dim == 0:
  809. assert torch.equal(k_cache_select, k_cache_ref)
  810. else:
  811. # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
  812. # breakpoint()
  813. if dtype is not torch.float8_e4m3fn:
  814. assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
  815. else:
  816. assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1)
  817. mult = 4 if dtype == torch.float8_e4m3fn else 2
  818. assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
  819. mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
  820. assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()
  821. def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype):
  822. num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3
  823. k_cache_paged = torch.randn(
  824. num_blocks, page_size, nheads_k, d, device=device, dtype=dtype
  825. )
  826. v_cache_paged = torch.randn(
  827. num_blocks, page_size, nheads_k, d, device=device, dtype=dtype
  828. )
  829. page_table = rearrange(
  830. torch.randperm(num_blocks, dtype=torch.int32, device=device),
  831. "(b nblocks) -> b nblocks",
  832. b=batch_size,
  833. )
  834. k_cache = rearrange(
  835. k_cache_paged[page_table.flatten()],
  836. "(b nblocks) block_size ... -> b (nblocks block_size) ...",
  837. b=batch_size,
  838. )[:, :seqlen_k]
  839. v_cache = rearrange(
  840. v_cache_paged[page_table.flatten()],
  841. "(b nblocks) block_size ... -> b (nblocks block_size) ...",
  842. b=batch_size,
  843. )[:, :seqlen_k]
  844. return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
  845. @pytest.mark.parametrize("dtype", [torch.bfloat16])
  846. @pytest.mark.parametrize("causal", [False, True])
  847. # @pytest.mark.parametrize('causal', [False])
  848. @pytest.mark.parametrize('d', [128])
  849. @pytest.mark.parametrize(
  850. "seqlen_q,seqlen_k",
  851. [
  852. (64, 8192),
  853. ],
  854. )
  855. def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
  856. device = "cuda"
  857. torch.random.manual_seed(0)
  858. batch_size = 2
  859. nheads = 16
  860. nheads_kv = 4
  861. # There was a bug where this would cause "unspecified launch failure" due to Cluster
  862. q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
  863. k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
  864. v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
  865. for _ in range(100):
  866. flash_attn_func(q, k, v, causal=causal)
  867. # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
  868. @pytest.mark.parametrize("dtype", [torch.bfloat16])
  869. @pytest.mark.parametrize("causal", [False, True])
  870. # @pytest.mark.parametrize('causal', [False])
  871. @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
  872. # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128])
  873. # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
  874. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
  875. # @pytest.mark.parametrize('d', [80])
  876. @pytest.mark.parametrize(
  877. "seqlen_q,seqlen_k",
  878. [
  879. (1, 239),
  880. (239, 1),
  881. (3, 799),
  882. (799, 3),
  883. (1024, 128),
  884. (97, 97),
  885. (128, 128),
  886. (200, 200),
  887. (256, 256),
  888. (257, 257),
  889. (384, 384),
  890. (512, 512),
  891. (768, 768),
  892. (1024, 1024),
  893. (2048, 2048),
  894. ],
  895. )
  896. def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
  897. device = "cuda"
  898. # set seed
  899. torch.random.manual_seed(0)
  900. # Simulate under memory load
  901. dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device)
  902. batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger
  903. nheads = 4
  904. q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
  905. k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
  906. v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
  907. torch.random.manual_seed(42)
  908. out0, lse0 = flash_attn_func(q, k, v, causal=causal)
  909. g = torch.randn_like(out0)
  910. dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
  911. # Numerical error if we just do any arithmetic on dq
  912. dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()
  913. for i in range(1000):
  914. torch.random.manual_seed(42)
  915. out, lse = flash_attn_func(q, k, v, causal=causal)
  916. assert torch.equal(out, out0)
  917. assert torch.equal(lse, lse0)
  918. dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  919. dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
  920. if not dq_equal:
  921. print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
  922. # breakpoint()
  923. assert torch.equal(dv, dv0)
  924. assert torch.equal(dk, dk0)
  925. assert dq_equal
  926. def attention_combine_ref(out_partial, lse_partial):
  927. """
  928. out_partial: (num_splits, batch_size, seqlen, nheads, d)
  929. lse_partial: (num_splits, batch_size, nheads, seqlen)
  930. """
  931. lse = torch.logsumexp(lse_partial, dim=0)
  932. scale = torch.exp(lse_partial - lse)
  933. scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale)
  934. out = (scale.unsqueeze(-1) * out_partial).sum(0)
  935. return out, lse
  936. @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
  937. # @pytest.mark.parametrize("dtype", [torch.float32])
  938. # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
  939. @pytest.mark.parametrize("d", [64, 96, 128, 192, 256])
  940. # @pytest.mark.parametrize("d", [128])
  941. @pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024, 2048])
  942. # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192])
  943. # @pytest.mark.parametrize("seqlen", [15])
  944. @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 155])
  945. # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11])
  946. # @pytest.mark.parametrize("num_splits", [128])
  947. def test_flash_attn_combine(num_splits, seqlen, d, dtype):
  948. if DISABLE_SPLIT:
  949. pytest.skip()
  950. device = "cuda"
  951. # set seed
  952. torch.random.manual_seed(1)
  953. batch_size = 5
  954. nheads = 16
  955. # batch_size = 1
  956. # nheads = 1
  957. out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor
  958. lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor
  959. # To test short-circuiting based on num_splits
  960. lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf")
  961. out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
  962. out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)
  963. out_pt = out_ref.to(dtype)
  964. print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  965. print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}")
  966. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  967. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  968. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  969. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  970. # breakpoint()
  971. assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)
  972. multiple = 2
  973. assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)
  974. # from flash_attn.utils.benchmark import pytorch_profiler
  975. # # pytorch_profiler(torch.sum, lse_partial)
  976. # pytorch_profiler(flash_attn_combine, out_partial, lse_partial)
  977. # pytorch_profiler(torch.sum, out_partial)