# test_flash_attn.py
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, _flash_attn_forward
from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref

ABS_TOL = 5e-3
REL_TOL = 1e-1
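

# Debugging helper: prints every element of `out` whose absolute or relative
# difference from `out_ref` exceeds the tolerances above.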
def print_diffs(out, out_ref):
    out_1d = out.flatten()
    out_ref_1d = out_ref.flatten()
    for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)):
        diff = e_o - e_o_ref
        abs_diff = abs(diff)
        abs_ref = abs(e_o_ref + 1e-5)
        relative_diff = abs_diff / abs_ref
        if abs_diff > ABS_TOL or relative_diff > REL_TOL:
            print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
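

# The decorator stack below sweeps dtype, MHA/GQA/MQA layout, causal and
# sliding-window masking, determinism, head dimension, and a grid of
# (seqlen_q, seqlen_k) shapes; the commented-out variants are narrower sweeps
# kept around for quick local runs.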
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("descale", [1.0])
# @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        # (257, 1),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, causal, local, deterministic, mha_type, dtype, descale
):
    device = "cuda"
    if dtype == torch.float8_e4m3fn:
        dtype_init = torch.float16
    else:
        dtype_init = dtype
    print(dtype)
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 4
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # nheads_kv = 2
    # batch_size = 9
    # nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    q = q.to(dtype)
    k = k.to(dtype)
    v = v.to(dtype)
    softmax_scale = q.shape[-1] ** (-0.5)
    descale_q = torch.tensor([descale], dtype=torch.float32, device="cuda")
    descale_k = torch.tensor([descale], dtype=torch.float32, device="cuda")
    descale_v = torch.tensor([descale], dtype=torch.float32, device="cuda")
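    # Per-tensor dequantization scale factors for the fp8 path; with the default
    # descale of 1.0 they are a no-op, and they are only passed to the kernel
    # when dtype is float8_e4m3fn.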
    if dtype != torch.float8_e4m3fn:
        out, lse = flash_attn_func(q, k, v, causal=causal, window_size=window_size, deterministic=deterministic)
    else:
        out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
            q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
        )
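    # Rebuild the inputs for the reference computation in the higher-precision
    # init dtype; on the fp8 path the descale factors are folded back in so the
    # reference sees the same effective values the kernel consumed.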
    q = q.to(dtype_init)
    k = k.to(dtype_init)
    v = v.to(dtype_init)
    if dtype == torch.float8_e4m3fn:
        descale_q = descale_q.to(dtype_init)
        descale_k = descale_k.to(dtype_init)
        descale_v = descale_v.to(dtype_init)
        q = q * descale_q
        k = k * descale_k
        v = v * descale_v
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
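    # out_ref is the upcast (fp32) reference; out_pt re-runs the same reference
    # in the test dtype with reordered ops and serves as the baseline for how
    # much rounding error a plain PyTorch implementation incurs.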
    # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # exp_sum = s_tmp.sum(-1)
    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
    # lse_ref = torch.logsumexp(qk, dim=-1)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # if not causal:
    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
    # breakpoint()
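    # The backward pass is only checked for head dimensions up to 128 and for
    # non-fp8 dtypes.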
    if d <= 128 and dtype != torch.float8_e4m3fn:
        g = torch.randn_like(out)
        do_o = (g.float() * out.float()).sum(-1)
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
        # P = torch.softmax(qk, -1)
        # dP = P * (dS - do_o.unsqueeze(1))
        # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
        # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
        # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
        # breakpoint()

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    # breakpoint()
    if dtype != torch.float8_e4m3fn:
        assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
    else:
        # just test correctness of fp8 kernel w/o further quantization techniques
        assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

    if d <= 128 and dtype != torch.float8_e4m3fn:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 3e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 3e-5
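

# For reference, a minimal standalone call mirroring the shapes exercised above
# (a sketch only, not part of the test suite; it assumes the flash_attn_func
# signature used in test_flash_attn_output, with q/k/v laid out as
# (batch, seqlen, nheads, headdim)):
#
#     q = torch.randn(4, 1024, 6, 128, device="cuda", dtype=torch.bfloat16)
#     k = torch.randn(4, 1024, 6, 128, device="cuda", dtype=torch.bfloat16)
#     v = torch.randn(4, 1024, 6, 128, device="cuda", dtype=torch.bfloat16)
#     out, lse = flash_attn_func(q, k, v, causal=True)  # out: (4, 1024, 6, 128)
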
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [256])
# @pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (113, 203),
        (128, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (512, 256),
        (640, 128),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, causal, local, deterministic, add_unused_qkv, mha_type, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 1
    # nheads = 1
    # nheads_kv = 1
    batch_size = 9
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
    )
    v = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
    )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
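    # When add_unused_qkv is True, split each padding mask into an "attended"
    # mask and an "unused" mask; the unused positions are later passed to
    # generate_qkv so the seqused_q / seqused_k code paths get exercised.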
    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
        if add_unused:
            another_mask = generate_random_padding_mask(max_seq_len, bs, device)
            attn_mask = torch.logical_and(padding_mask, another_mask)
            unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask)
        else:
            attn_mask = padding_mask
            unused_mask = None
        return attn_mask, unused_mask

    query_padding_mask, query_unused_mask = _gen_unused_masks(query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device)
    key_padding_mask, key_unused_mask = _gen_unused_masks(key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device)
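    # generate_qkv unpads q/k/v into the packed varlen layout (total_tokens,
    # nheads, d) and returns the cumulative sequence lengths, the seqused
    # tensors, and closures that pad the outputs/gradients back to (b, s, h, d).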
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
    # print("cu_seqlens_q: ", cu_seqlens_q)
    # print("cu_seqlens_k: ", cu_seqlens_k)
    # print("q_unpad, shape: ", q_unpad.shape)
    # print("k_unpad, shape: ", k_unpad.shape)
    # print("v_unpad, shape: ", v_unpad.shape)
    out_unpad, sm_lse = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        causal=causal,
        deterministic=deterministic,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
        window_size=window_size,
    )
    out = output_pad_fn(out_unpad)
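    # Zero the output rows at unused query positions before comparing against
    # the reference.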
    if query_unused_mask is not None:
        q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
        out.masked_fill_(q_zero_masking, 0.0)
    dropout_mask = None
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    if d <= 128:
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        if key_unused_mask is not None:
            k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
            dk.masked_fill_(k_zero_masking, 0.0)
            dv.masked_fill_(k_zero_masking, 0.0)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
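        # For batch entries whose key padding mask is entirely False, zero the
        # reference (and PyTorch) dK/dV as well so they compare cleanly.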
        zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
        dk_ref.masked_fill_(zero_masking, 0.0)
        dv_ref.masked_fill_(zero_masking, 0.0)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dk_pt.masked_fill_(zero_masking, 0.0)
        dv_pt.masked_fill_(zero_masking, 0.0)
        dq = dq_pad_fn(dq_unpad)
        if query_unused_mask is not None:
            dq.masked_fill_(q_zero_masking, 0.0)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if d <= 128:
        assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
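

# Benchmark-style test: the same document-masked attention is computed three ways
# (exact cu_seqlens with the true max_seqlen, exact cu_seqlens with max_seqlen set
# to the full 8192 tokens, and cu_seqlens padded out to length 8192 with
# optimize_for_doc_masking=True), and the timings and outputs are compared.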
def test_doc_mask():
    import numpy as np

    query = torch.randn(8192, 4, 64, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(8192, 4, 64, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(8192, 4, 64, device="cuda", dtype=torch.bfloat16)
    p = 0
    cu_seqlens_q = torch.tensor(
        [0, 1177, 2579, 3414, 3899, 4202, 4585, 5283, 5477, 5660, 7056, 7691, 8192],
        device="cuda",
        dtype=torch.int32,
    )
    cu_seqlens_k = torch.tensor(
        [0, 1177, 2579, 3414, 3899, 4202, 4585, 5283, 5477, 5660, 7056, 7691, 8192],
        device="cuda",
        dtype=torch.int32,
    )
    max_seqlen_q = 1402
    max_seqlen_k = 1402
    cu_seqlens_q_max = torch.full(
        (8192,),
        8192,
        device="cuda",
        dtype=torch.int32,
    )
    cu_seqlens_k_max = torch.full(
        (8192,),
        8192,
        device="cuda",
        dtype=torch.int32,
    )
    cu_seqlens_q_max[: cu_seqlens_q.numel()] = cu_seqlens_q
    cu_seqlens_k_max[: cu_seqlens_k.numel()] = cu_seqlens_k
    ncu_out = flash_attn_varlen_func(
        query,
        key,
        value,
        cu_seqlens_q_max,
        cu_seqlens_k_max,
        8192,
        8192,
    )
    # return
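    # Timing methodology: each variant runs 100 iterations bracketed by CUDA
    # events; the first few iterations are discarded as warm-up before averaging.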
    times = []
    for i in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out0 = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
        )
        end.record()
        torch.cuda.synchronize()
        if i > 3:
            times.append(start.elapsed_time(end))
    print(np.mean(times))

    times = []
    for i in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out1 = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q,
            cu_seqlens_k,
            8192,
            8192,
        )
        end.record()
        torch.cuda.synchronize()
        if i > 3:
            times.append(start.elapsed_time(end))
    print(np.mean(times))

    times = []
    for i in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out2 = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q_max,
            cu_seqlens_k_max,
            8192,
            8192,
            optimize_for_doc_masking=True,
        )
        end.record()
        torch.cuda.synchronize()
        if i > 3:
            times.append(start.elapsed_time(end))
    print(np.mean(times))
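    # All three variants are expected to produce the same attention output.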
    print(torch.allclose(out0[0], out1[0]))
    print(torch.allclose(out1[0], out2[0]))