# test_flash_attn.py
  1. import math
  2. import einops
  3. import pytest
  4. import torch
  5. import torch.nn.functional as F
  6. from einops import rearrange, repeat
  7. from flash_attn_interface import (
  8. _flash_attn_forward,
  9. flash_attn_func,
  10. flash_attn_varlen_func,
  11. )
  12. from tests.test_util import (
  13. attention_ref,
  14. construct_local_mask,
  15. generate_qkv,
  16. generate_random_padding_mask,
  17. )
  18. ABS_TOL = 5e-3
  19. REL_TOL = 1e-1
  20. def print_diffs(out, out_ref):
  21. out_1d = out.flatten()
  22. out_ref_1d = out_ref.flatten()
  23. for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)):
  24. diff = e_o - e_o_ref
  25. abs_diff = abs(diff)
  26. abs_ref = abs(e_o_ref + 1e-5)
  27. relative_diff = abs_diff / abs_ref
  28. if abs_diff > ABS_TOL or relative_diff > REL_TOL:
  29. print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
  30. @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  31. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  32. @pytest.mark.parametrize("causal", [False, True])
  33. @pytest.mark.parametrize("local", [False, True])
  34. @pytest.mark.parametrize("deterministic", [True])
  35. @pytest.mark.parametrize("gqa_parallel", [False, True])
  36. @pytest.mark.parametrize("d", [64, 128, 256])
  37. # @pytest.mark.parametrize("descale", [1.0])
  38. @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0])
  39. @pytest.mark.parametrize(
  40. "seqlen_q,seqlen_k",
  41. [
  42. (1, 1),
  43. (64, 128),
  44. (128, 128),
  45. (256, 256),
  46. (113, 203),
  47. (128, 217),
  48. (113, 211),
  49. (108, 256),
  50. (256, 512),
  51. (384, 256),
  52. (640, 128),
  53. (512, 256),
  54. (1024, 1024),
  55. (1023, 1024),
  56. (1024, 1023),
  57. (4096, 4096),
  58. (4224, 4224),
  59. ],
  60. )
  61. def test_flash_attn_output_fp8(
  62. seqlen_q,
  63. seqlen_k,
  64. d,
  65. causal,
  66. local,
  67. deterministic,
  68. mha_type,
  69. dtype,
  70. descale,
  71. gqa_parallel,
  72. ):
  73. device = "cuda"
  74. dtype_init = torch.bfloat16
  75. print(dtype)
  76. print("causal", causal)
  77. print("local", local)
  78. print("gqa_parallel", gqa_parallel)
  79. # set seed
  80. torch.random.manual_seed(42)
  81. # batch_size = 40
  82. # nheads = 16
  83. batch_size = 4
  84. nheads = 6
  85. nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  86. # nheads_kv = 1
  87. # batch_size = 9
  88. # nheads = 6
  89. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  90. q = torch.randn(
  91. batch_size,
  92. seqlen_q,
  93. nheads,
  94. d,
  95. device=device,
  96. dtype=dtype_init,
  97. requires_grad=True,
  98. )
  99. k = torch.randn(
  100. batch_size,
  101. seqlen_k,
  102. nheads_kv,
  103. d,
  104. device=device,
  105. dtype=dtype_init,
  106. requires_grad=True,
  107. )
  108. v = torch.randn(
  109. batch_size,
  110. seqlen_k,
  111. nheads_kv,
  112. d,
  113. device=device,
  114. dtype=dtype_init,
  115. requires_grad=True,
  116. )
  117. q = q.to(dtype)
  118. k = k.to(dtype)
  119. v = v.to(dtype)
  120. softmax_scale = q.shape[-1] ** (-0.5)
  121. descale_q = torch.tensor([descale], dtype=torch.float32, device="cuda")
  122. descale_k = torch.tensor([descale], dtype=torch.float32, device="cuda")
  123. descale_v = torch.tensor([descale], dtype=torch.float32, device="cuda")
  124. out, lse = flash_attn_func(
  125. q,
  126. k,
  127. v,
  128. causal=causal,
  129. window_size=window_size,
  130. deterministic=deterministic,
  131. gqa_parallel=gqa_parallel,
  132. descale_q=descale_q,
  133. descale_k=descale_k,
  134. descale_v=descale_v,
  135. )
  136. q = q.to(dtype_init)
  137. k = k.to(dtype_init)
  138. v = v.to(dtype_init)
  139. descale_q = descale_q.to(dtype_init)
  140. descale_k = descale_k.to(dtype_init)
  141. descale_v = descale_v.to(dtype_init)
  142. q = q * descale_q
  143. k = k * descale_k
  144. v = v * descale_v
  145. out_ref, attn_ref = attention_ref(
  146. q,
  147. k,
  148. v,
  149. None,
  150. None,
  151. causal=causal,
  152. window_size=window_size,
  153. )
  154. out_pt, attn_pt = attention_ref(
  155. q,
  156. k,
  157. v,
  158. None,
  159. None,
  160. causal=causal,
  161. window_size=window_size,
  162. upcast=False,
  163. reorder_ops=True,
  164. )
  165. # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
  166. # m = qk.amax(-1, keepdim=True)
  167. # s_tmp = torch.exp((qk - m) / math.sqrt(d))
  168. # exp_sum = s_tmp.sum(-1)
  169. # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
  170. # lse_ref = torch.logsumexp(qk, dim=-1)
  171. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  172. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  173. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  174. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  175. # if not causal:
  176. # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  177. # breakpoint()
  178. # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
  179. # P = torch.softmax(qk, -1)
  180. # dP = P * (dS - do_o.unsqueeze(1))
  181. # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
  182. # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
  183. # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
  184. # breakpoint()
  185. # assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 1e-2
  186. atol = 4 * (out_pt - out_ref).abs().max().item() + 1e-2
  187. torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=atol, check_dtype=False)
  188. @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
  189. # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  190. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  191. # @pytest.mark.parametrize("mha_type", ["mha"])
  192. @pytest.mark.parametrize("causal", [False, True])
  193. # @pytest.mark.parametrize("causal", [False])
  194. @pytest.mark.parametrize("local", [False, True])
  195. # @pytest.mark.parametrize("local", [True])
  196. @pytest.mark.parametrize("deterministic", [False, True])
  197. # @pytest.mark.parametrize("deterministic", [True])
  198. @pytest.mark.parametrize("gqa_parallel", [False, True])
  199. # @pytest.mark.parametrize("gqa_parallel", [False])
  200. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  201. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
  202. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  203. # @pytest.mark.parametrize('d', [56, 80])
  204. # @pytest.mark.parametrize("d", [64, 128, 256])
  205. # @pytest.mark.parametrize("d", [64, 96, 128])
  206. # @pytest.mark.parametrize("d", [64])
  207. @pytest.mark.parametrize("d", [64, 128, 256])
  208. @pytest.mark.parametrize("descale", [1.0])
  209. # @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
  210. @pytest.mark.parametrize(
  211. "seqlen_q,seqlen_k",
  212. [
  213. (1, 1),
  214. (64, 128),
  215. (128, 128),
  216. (256, 256),
  217. (113, 203),
  218. (128, 217),
  219. (113, 211),
  220. (108, 256),
  221. (256, 512),
  222. (384, 256),
  223. (640, 128),
  224. (512, 256),
  225. (1024, 1024),
  226. (1023, 1024),
  227. (1024, 1023),
  228. (4096, 4096),
  229. (4224, 4224),
  230. ],
  231. )
  232. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
  233. def test_flash_attn_output(
  234. seqlen_q,
  235. seqlen_k,
  236. d,
  237. causal,
  238. local,
  239. deterministic,
  240. mha_type,
  241. dtype,
  242. descale,
  243. gqa_parallel,
  244. ):
  245. device = "cuda"
  246. if dtype == torch.float8_e4m3fn:
  247. dtype_init = torch.bfloat16
  248. else:
  249. dtype_init = dtype
  250. print(dtype)
  251. print("causal", causal)
  252. print("local", local)
  253. print("gqa_parallel", gqa_parallel)
  254. # set seed
  255. torch.random.manual_seed(42)
  256. # batch_size = 40
  257. # nheads = 16
  258. batch_size = 4
  259. nheads = 6
  260. nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  261. # nheads_kv = 1
  262. # batch_size = 9
  263. # nheads = 6
  264. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  265. q = torch.randn(
  266. batch_size,
  267. seqlen_q,
  268. nheads,
  269. d,
  270. device=device,
  271. dtype=dtype_init,
  272. requires_grad=True,
  273. )
  274. k = torch.randn(
  275. batch_size,
  276. seqlen_k,
  277. nheads_kv,
  278. d,
  279. device=device,
  280. dtype=dtype_init,
  281. requires_grad=True,
  282. )
  283. v = torch.randn(
  284. batch_size,
  285. seqlen_k,
  286. nheads_kv,
  287. d,
  288. device=device,
  289. dtype=dtype_init,
  290. requires_grad=True,
  291. )
  292. q = q.to(dtype)
  293. k = k.to(dtype)
  294. v = v.to(dtype)
  295. softmax_scale = q.shape[-1] ** (-0.5)
  296. descale_q = torch.tensor([descale], dtype=torch.float32, device="cuda")
  297. descale_k = torch.tensor([descale], dtype=torch.float32, device="cuda")
  298. descale_v = torch.tensor([descale], dtype=torch.float32, device="cuda")
  299. if dtype != torch.float8_e4m3fn:
  300. out, lse = flash_attn_func(
  301. q,
  302. k,
  303. v,
  304. causal=causal,
  305. window_size=window_size,
  306. deterministic=deterministic,
  307. gqa_parallel=gqa_parallel,
  308. )
  309. else:
  310. out, lse = flash_attn_func(
  311. q,
  312. k,
  313. v,
  314. causal=causal,
  315. window_size=window_size,
  316. deterministic=deterministic,
  317. gqa_parallel=gqa_parallel,
  318. descale_q=descale_q,
  319. descale_k=descale_k,
  320. descale_v=descale_v,
  321. )
  322. q = q.to(dtype_init)
  323. k = k.to(dtype_init)
  324. v = v.to(dtype_init)
  325. if dtype == torch.float8_e4m3fn:
  326. descale_q = descale_q.to(dtype_init)
  327. descale_k = descale_k.to(dtype_init)
  328. descale_v = descale_v.to(dtype_init)
  329. q = q * descale_q
  330. k = k * descale_k
  331. v = v * descale_v
  332. out_ref, attn_ref = attention_ref(
  333. q,
  334. k,
  335. v,
  336. None,
  337. None,
  338. causal=causal,
  339. window_size=window_size,
  340. )
  341. out_pt, attn_pt = attention_ref(
  342. q,
  343. k,
  344. v,
  345. None,
  346. None,
  347. causal=causal,
  348. window_size=window_size,
  349. upcast=False,
  350. reorder_ops=True,
  351. )
  352. # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
  353. # m = qk.amax(-1, keepdim=True)
  354. # s_tmp = torch.exp((qk - m) / math.sqrt(d))
  355. # exp_sum = s_tmp.sum(-1)
  356. # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
  357. # lse_ref = torch.logsumexp(qk, dim=-1)
  358. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  359. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  360. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  361. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  362. # if not causal:
  363. # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  364. # breakpoint()
  365. if d <= 128 and dtype != torch.float8_e4m3fn:
  366. g = torch.randn_like(out)
  367. do_o = (g.float() * out.float()).sum(-1)
  368. dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  369. dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
  370. dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
  371. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  372. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  373. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  374. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  375. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  376. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  377. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  378. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  379. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  380. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  381. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  382. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  383. # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
  384. # P = torch.softmax(qk, -1)
  385. # dP = P * (dS - do_o.unsqueeze(1))
  386. # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
  387. # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
  388. # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
  389. # breakpoint()
  390. # Check that FlashAttention's numerical error is at most twice the numerical error
  391. # of a Pytorch implementation.
  392. # breakpoint()
  393. if dtype != torch.float8_e4m3fn:
  394. assert (out - out_ref).abs().max().item() <= 2 * (
  395. out_pt - out_ref
  396. ).abs().max().item() + 3e-5
  397. else:
  398. # just test correctness of fp8 kernel w/o further quantization techniques
  399. assert (out - out_ref).abs().max().item() <= 4 * (
  400. out_pt - out_ref
  401. ).abs().max().item() + 2e-2
  402. if d <= 128 and dtype != torch.float8_e4m3fn:
  403. assert (dq - dq_ref).abs().max().item() <= 2 * (
  404. dq_pt - dq_ref
  405. ).abs().max().item() + 3e-5
  406. assert (dk - dk_ref).abs().max().item() <= 2 * (
  407. dk_pt - dk_ref
  408. ).abs().max().item() + 3e-5
  409. assert (dv - dv_ref).abs().max().item() <= 2 * (
  410. dv_pt - dv_ref
  411. ).abs().max().item() + 3e-5
  412. @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
  413. # @pytest.mark.parametrize("dtype", [torch.float16])
  414. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  415. # @pytest.mark.parametrize("mha_type", ["mha"])
  416. @pytest.mark.parametrize("causal", [False, True])
  417. # @pytest.mark.parametrize("causal", [True])
  418. @pytest.mark.parametrize("local", [False, True])
  419. # @pytest.mark.parametrize("local", [False])
  420. @pytest.mark.parametrize("deterministic", [False, True])
  421. # @pytest.mark.parametrize("deterministic", [False])
  422. @pytest.mark.parametrize("add_unused_qkv", [False, True])
  423. # @pytest.mark.parametrize("add_unused_qkv", [True])
  424. # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
  425. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  426. # @pytest.mark.parametrize('d', [256])
  427. # @pytest.mark.parametrize("d", [64, 128, 256])
  428. @pytest.mark.parametrize("d", [64, 128])
  429. # @pytest.mark.parametrize("d", [128])
  430. @pytest.mark.parametrize(
  431. "seqlen_q,seqlen_k",
  432. [
  433. (1, 1),
  434. (1, 3),
  435. (2, 1),
  436. (511, 1),
  437. (3, 513),
  438. (64, 128),
  439. (113, 203),
  440. (128, 128),
  441. (128, 217),
  442. (113, 211),
  443. (108, 256),
  444. (256, 512),
  445. (384, 256),
  446. (512, 256),
  447. (640, 128),
  448. (1024, 1024),
  449. (1023, 1024),
  450. (1024, 1023),
  451. (2048, 2048),
  452. ],
  453. )
  454. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
  455. def test_flash_attn_varlen_output(
  456. seqlen_q, seqlen_k, d, causal, local, deterministic, add_unused_qkv, mha_type, dtype
  457. ):
  458. if (
  459. max(seqlen_q, seqlen_k) >= 2048
  460. and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
  461. ):
  462. pytest.skip() # Reference implementation OOM
  463. device = "cuda"
  464. # set seed
  465. torch.random.manual_seed(0)
  466. # batch_size = 1
  467. # nheads = 1
  468. # nheads_kv = 1
  469. batch_size = 9
  470. nheads = 6
  471. nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  472. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  473. q = torch.randn(
  474. batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True
  475. )
  476. k = torch.randn(
  477. batch_size,
  478. seqlen_k,
  479. nheads_kv,
  480. d,
  481. device=device,
  482. dtype=dtype,
  483. requires_grad=True,
  484. )
  485. v = torch.randn(
  486. batch_size,
  487. seqlen_k,
  488. nheads_kv,
  489. d,
  490. device=device,
  491. dtype=dtype,
  492. requires_grad=True,
  493. )
  494. query_padding_mask = generate_random_padding_mask(
  495. seqlen_q, batch_size, device, mode="random", zero_lengths=False
  496. )
  497. key_padding_mask = generate_random_padding_mask(
  498. seqlen_k, batch_size, device, mode="random", zero_lengths=True
  499. )
  500. # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
  501. def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
  502. if add_unused:
  503. another_mask = generate_random_padding_mask(max_seq_len, bs, device)
  504. attn_mask = torch.logical_and(padding_mask, another_mask)
  505. unused_mask = torch.logical_xor(
  506. torch.logical_or(padding_mask, another_mask), attn_mask
  507. )
  508. else:
  509. attn_mask = padding_mask
  510. unused_mask = None
  511. return attn_mask, unused_mask
  512. query_padding_mask, query_unused_mask = _gen_unused_masks(
  513. query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
  514. )
  515. key_padding_mask, key_unused_mask = _gen_unused_masks(
  516. key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
  517. )
  518. (
  519. q_unpad,
  520. k_unpad,
  521. v_unpad,
  522. cu_seqlens_q,
  523. cu_seqlens_k,
  524. seqused_q,
  525. seqused_k,
  526. max_seqlen_q,
  527. max_seqlen_k,
  528. q,
  529. k,
  530. v,
  531. output_pad_fn,
  532. dq_pad_fn,
  533. dk_pad_fn,
  534. ) = generate_qkv(
  535. q,
  536. k,
  537. v,
  538. query_padding_mask,
  539. key_padding_mask,
  540. kvpacked=False,
  541. query_unused_mask=query_unused_mask,
  542. key_unused_mask=key_unused_mask,
  543. )
  544. # print("cu_seqlens_q: ", cu_seqlens_q)
  545. # print("cu_seqlens_k: ", cu_seqlens_k)
  546. # print("q_unpad, shape: ", q_unpad.shape)
  547. # print("k_unpad, shape: ", k_unpad.shape)
  548. # print("v_unpad, shape: ", v_unpad.shape)
  549. out_unpad, sm_lse = flash_attn_varlen_func(
  550. q_unpad,
  551. k_unpad,
  552. v_unpad,
  553. cu_seqlens_q,
  554. cu_seqlens_k,
  555. max_seqlen_q,
  556. max_seqlen_k,
  557. causal=causal,
  558. deterministic=deterministic,
  559. seqused_q=seqused_q,
  560. seqused_k=seqused_k,
  561. window_size=window_size,
  562. )
  563. out = output_pad_fn(out_unpad)
  564. if query_unused_mask is not None:
  565. q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
  566. out.masked_fill_(q_zero_masking, 0.0)
  567. dropout_mask = None
  568. out_ref, attn_ref = attention_ref(
  569. q,
  570. k,
  571. v,
  572. query_padding_mask,
  573. key_padding_mask,
  574. causal=causal,
  575. window_size=window_size,
  576. )
  577. out_pt, attn_pt = attention_ref(
  578. q,
  579. k,
  580. v,
  581. query_padding_mask,
  582. key_padding_mask,
  583. causal=causal,
  584. window_size=window_size,
  585. upcast=False,
  586. reorder_ops=True,
  587. )
  588. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  589. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  590. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  591. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  592. g = torch.randn_like(out)
  593. if d <= 128:
  594. (
  595. dq_unpad,
  596. dk_unpad,
  597. dv_unpad,
  598. ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
  599. dk = dk_pad_fn(dk_unpad)
  600. dv = dk_pad_fn(dv_unpad)
  601. if key_unused_mask is not None:
  602. k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
  603. dk.masked_fill_(k_zero_masking, 0.0)
  604. dv.masked_fill_(k_zero_masking, 0.0)
  605. (
  606. dq_ref,
  607. dk_ref,
  608. dv_ref,
  609. ) = torch.autograd.grad(out_ref, (q, k, v), g)
  610. zero_masking = rearrange(
  611. torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"
  612. )
  613. dk_ref.masked_fill_(zero_masking, 0.0)
  614. dv_ref.masked_fill_(zero_masking, 0.0)
  615. (
  616. dq_pt,
  617. dk_pt,
  618. dv_pt,
  619. ) = torch.autograd.grad(out_pt, (q, k, v), g)
  620. dk_pt.masked_fill_(zero_masking, 0.0)
  621. dv_pt.masked_fill_(zero_masking, 0.0)
  622. dq = dq_pad_fn(dq_unpad)
  623. if query_unused_mask is not None:
  624. dq.masked_fill_(q_zero_masking, 0.0)
  625. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  626. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  627. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  628. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  629. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  630. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  631. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  632. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  633. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  634. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  635. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  636. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  637. # Check that FlashAttention's numerical error is at most twice the numerical error
  638. # of a Pytorch implementation.
  639. assert (out - out_ref).abs().max().item() <= 2 * (
  640. out_pt - out_ref
  641. ).abs().max().item()
  642. if d <= 128:
  643. assert (dq - dq_ref).abs().max().item() < 1e-4 or (
  644. dq - dq_ref
  645. ).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
  646. assert (dk - dk_ref).abs().max().item() < 1e-4 or (
  647. dk - dk_ref
  648. ).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
  649. assert (dv - dv_ref).abs().max().item() < 1e-4 or (
  650. dv - dv_ref
  651. ).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
  652. @pytest.mark.parametrize("dtype", [torch.bfloat16])
  653. # @pytest.mark.parametrize("dtype", [torch.float16])
  654. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  655. @pytest.mark.parametrize("causal", [False, True])
  656. # @pytest.mark.parametrize("causal", [False])
  657. @pytest.mark.parametrize("deterministic", [True, False])
  658. # @pytest.mark.parametrize("deterministic", [False])
  659. # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
  660. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  661. # @pytest.mark.parametrize('d', [128])
  662. # @pytest.mark.parametrize("d", [64, 128, 256])
  663. @pytest.mark.parametrize("d", [128, 64])
  664. # @pytest.mark.parametrize("d", [128])
  665. @pytest.mark.parametrize(
  666. "seqlen_q,seqlen_k",
  667. [
  668. # (1, 1),
  669. # (1, 3),
  670. # (2, 1),
  671. # (511, 1),
  672. # (3, 513),
  673. # (64, 128),
  674. # (113, 203),
  675. # (128, 128),
  676. # (128, 217),
  677. # (113, 211),
  678. # (108, 256),
  679. (256, 512),
  680. # (384, 256),
  681. (768, 512),
  682. # (512, 256),
  683. # (640, 128),
  684. (1024, 1024),
  685. # (1023, 1024),
  686. # (1024, 1023),
  687. # (2048, 2048),
  688. ],
  689. )
  690. @pytest.mark.parametrize("add_unused_qkv", [False])
  691. @pytest.mark.parametrize("shuffle_pages", [True, False])
  692. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
  693. def test_flash_attn_paged1(
  694. seqlen_q,
  695. seqlen_k,
  696. d,
  697. causal,
  698. deterministic,
  699. add_unused_qkv,
  700. mha_type,
  701. dtype,
  702. shuffle_pages,
  703. ):
  704. run_conf = locals()
  705. if (
  706. max(seqlen_q, seqlen_k) >= 2048
  707. and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
  708. ):
  709. pytest.skip() # Reference implementation OOM
  710. device = "cuda"
  711. # set seed
  712. torch.random.manual_seed(0)
  713. # batch_size = 1
  714. # nheads = 1
  715. batch_size = 9
  716. nheads = 6
  717. nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  718. q = torch.randn(
  719. batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True
  720. )
  721. page_size = 256
  722. num_pages = batch_size * seqlen_k // page_size
  723. assert seqlen_k % page_size == 0, "Max seqlen must be divisible by page size"
  724. block_table = torch.reshape(
  725. torch.arange(num_pages, dtype=torch.int32, device=device), (batch_size, -1)
  726. )
  727. k_paged = torch.randn(
  728. num_pages,
  729. page_size,
  730. nheads_kv,
  731. d,
  732. device=device,
  733. dtype=dtype,
  734. requires_grad=True,
  735. )
  736. v_paged = torch.randn(
  737. num_pages,
  738. page_size,
  739. nheads_kv,
  740. d,
  741. device=device,
  742. dtype=dtype,
  743. requires_grad=True,
  744. )
  745. if shuffle_pages:
  746. block_table = torch.randperm(num_pages, dtype=torch.int32, device=device).view(
  747. batch_size, -1
  748. )
  749. k = torch.index_select(k_paged, 0, block_table.view(-1)).view(
  750. batch_size, seqlen_k, nheads_kv, d
  751. )
  752. v = torch.index_select(v_paged, 0, block_table.view(-1)).view(
  753. batch_size, seqlen_k, nheads_kv, d
  754. )
  755. else:
  756. k = torch.reshape(k_paged, (batch_size, seqlen_k, nheads_kv, d))
  757. v = torch.reshape(v_paged, (batch_size, seqlen_k, nheads_kv, d))
  758. query_padding_mask = generate_random_padding_mask(
  759. seqlen_q, batch_size, device, mode="random"
  760. )
  761. key_padding_mask = generate_random_padding_mask(
  762. seqlen_k, batch_size, device, mode="random"
  763. )
  764. # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
  765. def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
  766. if add_unused:
  767. another_mask = generate_random_padding_mask(max_seq_len, bs, device)
  768. attn_mask = torch.logical_and(padding_mask, another_mask)
  769. unused_mask = torch.logical_xor(
  770. torch.logical_or(padding_mask, another_mask), attn_mask
  771. )
  772. else:
  773. attn_mask = padding_mask
  774. unused_mask = None
  775. return attn_mask, unused_mask
  776. query_padding_mask, query_unused_mask = _gen_unused_masks(
  777. query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
  778. )
  779. key_padding_mask, key_unused_mask = _gen_unused_masks(
  780. key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
  781. )
  782. (
  783. q_unpad,
  784. k_unpad,
  785. v_unpad,
  786. cu_seqlens_q,
  787. cu_seqlens_k,
  788. seqused_q,
  789. seqused_k,
  790. max_seqlen_q,
  791. max_seqlen_k,
  792. q,
  793. k,
  794. v,
  795. output_pad_fn,
  796. dq_pad_fn,
  797. dk_pad_fn,
  798. ) = generate_qkv(
  799. q,
  800. k,
  801. v,
  802. query_padding_mask,
  803. key_padding_mask,
  804. kvpacked=False,
  805. query_unused_mask=query_unused_mask,
  806. key_unused_mask=key_unused_mask,
  807. )
  808. # print("cu_seqlens_q: ", cu_seqlens_q)
  809. # print("cu_seqlens_k: ", cu_seqlens_k)
  810. # print("q_unpad, shape: ", q_unpad.shape)
  811. # print("k_unpad, shape: ", k_unpad.shape)
  812. # print("v_unpad, shape: ", v_unpad.shape)
  813. out_unpad, sm_lse = flash_attn_varlen_func(
  814. q_unpad,
  815. k_paged,
  816. v_paged,
  817. cu_seqlens_q,
  818. cu_seqlens_k,
  819. max_seqlen_q,
  820. max_seqlen_k,
  821. causal=causal,
  822. deterministic=deterministic,
  823. block_table=block_table,
  824. )
  825. out = output_pad_fn(out_unpad)
  826. out_unpaged_unpad, sm_unpaged_lse = flash_attn_varlen_func(
  827. q_unpad,
  828. k_unpad,
  829. v_unpad,
  830. cu_seqlens_q,
  831. cu_seqlens_k,
  832. max_seqlen_q,
  833. max_seqlen_k,
  834. causal=causal,
  835. deterministic=deterministic,
  836. )
  837. out_unpaged = output_pad_fn(out_unpaged_unpad)
  838. dropout_mask = None
  839. out_ref, attn_ref = attention_ref(
  840. q,
  841. k,
  842. v,
  843. query_padding_mask,
  844. key_padding_mask,
  845. causal=causal,
  846. )
  847. out_pt, attn_pt = attention_ref(
  848. q,
  849. k,
  850. v,
  851. query_padding_mask,
  852. key_padding_mask,
  853. causal=causal,
  854. upcast=False,
  855. reorder_ops=True,
  856. )
  857. # print(f"{k.stride()=}, {v.stride()=}, {k_paged.stride()=}, {v_paged.stride()=}, {block_table.stride()=}")
  858. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  859. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  860. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  861. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  862. print(f"Output max diff paged vs varlen: {(out - out_unpaged).abs().max().item()}")
  863. print(
  864. f"Output mean diff paged vs varlen: {(out - out_unpaged).abs().mean().item()}"
  865. )
  866. # Check that FlashAttention's numerical error is at most twice the numerical error
  867. # of a Pytorch implementation.
  868. # import fbvscode; fbvscode.set_trace()
  869. assert (out - out_ref).abs().max().item() <= 2 * (
  870. out_pt - out_ref
  871. ).abs().max().item()
  872. @pytest.mark.parametrize("dtype", ([torch.bfloat16]))
  873. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  874. @pytest.mark.parametrize("local", [False])
  875. # @pytest.mark.parametrize("local", [True])
  876. @pytest.mark.parametrize(
  877. "d", [128, 64]
  878. ) # [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
  879. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  880. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
  881. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  882. # @pytest.mark.parametrize('d', [56, 80])
  883. # @pytest.mark.parametrize("d", [64])
  884. @pytest.mark.parametrize("swap_sq_sk", [False, True])
  885. # @pytest.mark.parametrize("swap_sq_sk", [True])
  886. @pytest.mark.parametrize(
  887. "seqlen_q,seqlen_k",
  888. [
  889. (1, 239),
  890. (3, 799),
  891. (127, 512),
  892. (127, 513),
  893. (113, 203),
  894. (128, 217),
  895. (113, 211),
  896. (108, 256),
  897. (256, 512),
  898. (1023, 1024),
  899. ],
  900. )
  901. # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
  902. @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
  903. # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
  904. def test_flash_attn_varlen_paged2(
  905. seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
  906. ):
  907. # Test ported from FlashAttention V2 test test_flash_attn_varlen_causal
  908. def _generate_block_kvcache(
  909. seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
  910. ):
  911. num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
  912. k_cache_paged = torch.randn(
  913. num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
  914. )
  915. v_cache_paged = torch.randn(
  916. num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
  917. )
  918. block_table = rearrange(
  919. torch.randperm(num_blocks, dtype=torch.int32, device=device),
  920. "(b nblocks) -> b nblocks",
  921. b=batch_size,
  922. )
  923. k_cache = rearrange(
  924. # pytorch 1.12 doesn't have indexing with int32
  925. k_cache_paged[block_table.to(dtype=torch.long).flatten()],
  926. "(b nblocks) block_size ... -> b (nblocks block_size) ...",
  927. b=batch_size,
  928. )[:, :seqlen_k]
  929. v_cache = rearrange(
  930. v_cache_paged[block_table.to(dtype=torch.long).flatten()],
  931. "(b nblocks) block_size ... -> b (nblocks block_size) ...",
  932. b=batch_size,
  933. )[:, :seqlen_k]
  934. return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
  935. if (
  936. max(seqlen_q, seqlen_k) >= 2048
  937. and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
  938. ):
  939. pytest.skip() # Reference implementation OOM
  940. if swap_sq_sk:
  941. seqlen_q, seqlen_k = seqlen_k, seqlen_q
  942. device = "cuda"
  943. causal = True
  944. # set seed
  945. torch.random.manual_seed(0)
  946. batch_size = 8
  947. nheads = 9
  948. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  949. q = torch.randn(
  950. batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True
  951. )
  952. if paged_kv_block_size is None:
  953. k = torch.randn(
  954. batch_size,
  955. seqlen_k,
  956. nheads,
  957. d,
  958. device=device,
  959. dtype=dtype,
  960. requires_grad=True,
  961. )
  962. v = torch.randn(
  963. batch_size,
  964. seqlen_k,
  965. nheads,
  966. d,
  967. device=device,
  968. dtype=dtype,
  969. requires_grad=True,
  970. )
  971. block_table = None
  972. else:
  973. k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = (
  974. _generate_block_kvcache(
  975. seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
  976. )
  977. )
  978. query_padding_mask = generate_random_padding_mask(
  979. seqlen_q, batch_size, device, mode="random"
  980. )
  981. key_padding_mask = generate_random_padding_mask(
  982. seqlen_k, batch_size, device, mode="random"
  983. )
  984. def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
  985. if add_unused:
  986. another_mask = generate_random_padding_mask(max_seq_len, bs, device)
  987. attn_mask = torch.logical_and(padding_mask, another_mask)
  988. unused_mask = torch.logical_xor(
  989. torch.logical_or(padding_mask, another_mask), attn_mask
  990. )
  991. else:
  992. attn_mask = padding_mask
  993. unused_mask = None
  994. return attn_mask, unused_mask
  995. query_padding_mask, query_unused_mask = _gen_unused_masks(
  996. query_padding_mask, False, seqlen_q, batch_size, q.device
  997. )
  998. key_padding_mask, key_unused_mask = _gen_unused_masks(
  999. key_padding_mask, False, seqlen_k, batch_size, k.device
  1000. )
  1001. (
  1002. q_unpad,
  1003. k_unpad,
  1004. v_unpad,
  1005. cu_seqlens_q,
  1006. cu_seqlens_k,
  1007. seqused_q,
  1008. seqused_k,
  1009. max_seqlen_q,
  1010. max_seqlen_k,
  1011. q,
  1012. k,
  1013. v,
  1014. output_pad_fn,
  1015. dq_pad_fn,
  1016. dk_pad_fn,
  1017. ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
  1018. out_unpad, sm_lse = flash_attn_varlen_func(
  1019. q_unpad,
  1020. k_unpad if paged_kv_block_size is None else k_cache_paged,
  1021. v_unpad if paged_kv_block_size is None else v_cache_paged,
  1022. cu_seqlens_q,
  1023. cu_seqlens_k,
  1024. max_seqlen_q,
  1025. max_seqlen_k,
  1026. causal=causal,
  1027. block_table=block_table,
  1028. )
  1029. out = output_pad_fn(out_unpad)
  1030. out_ref, attn_ref = attention_ref(
  1031. q,
  1032. k,
  1033. v,
  1034. query_padding_mask,
  1035. key_padding_mask,
  1036. None,
  1037. 0.0,
  1038. None,
  1039. causal=causal,
  1040. window_size=window_size,
  1041. )
  1042. out_pt, attn_pt = attention_ref(
  1043. q,
  1044. k,
  1045. v,
  1046. query_padding_mask,
  1047. key_padding_mask,
  1048. None,
  1049. 0.0,
  1050. None,
  1051. causal=causal,
  1052. window_size=window_size,
  1053. upcast=False,
  1054. reorder_ops=True,
  1055. )
  1056. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  1057. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  1058. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  1059. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  1060. g = torch.randn_like(out)
  1061. do_o = (g.float() * out.float()).sum(-1)
  1062. test_backward = block_table is None
  1063. if test_backward:
  1064. (
  1065. dq_unpad,
  1066. dk_unpad,
  1067. dv_unpad,
  1068. ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
  1069. dq = dq_pad_fn(dq_unpad)
  1070. dk = dk_pad_fn(dk_unpad)
  1071. dv = dk_pad_fn(dv_unpad)
  1072. (
  1073. dq_ref,
  1074. dk_ref,
  1075. dv_ref,
  1076. ) = torch.autograd.grad(out_ref, (q, k, v), g)
  1077. (
  1078. dq_pt,
  1079. dk_pt,
  1080. dv_pt,
  1081. ) = torch.autograd.grad(out_pt, (q, k, v), g)
  1082. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  1083. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  1084. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  1085. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  1086. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  1087. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  1088. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  1089. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  1090. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  1091. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  1092. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  1093. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  1094. # Check that FlashAttention's numerical error is at most twice the numerical error
  1095. # of a Pytorch implementation.
  1096. assert (out - out_ref).abs().max().item() <= 2 * (
  1097. out_pt - out_ref
  1098. ).abs().max().item() + 1e-5
  1099. if test_backward:
  1100. assert (dq - dq_ref).abs().max().item() <= 2 * (
  1101. dq_pt - dq_ref
  1102. ).abs().max().item() + 1e-5
  1103. assert (dk - dk_ref).abs().max().item() <= 2 * (
  1104. dk_pt - dk_ref
  1105. ).abs().max().item() + 1e-5
  1106. assert (dv - dv_ref).abs().max().item() <= 2 * (
  1107. dv_pt - dv_ref
  1108. ).abs().max().item() + 1e-5
  1109. if __name__ == "__main__":
  1110. test_flash_attn_varlen_causal(512, 768, False, 128, False, 256, torch.bfloat16)