# benchmark_attn.py
from collections import namedtuple
from functools import partial
import math
from typing import NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

try:
    import cudnn
except ImportError:
    cudnn = None
# cudnn = None

Timing = NamedTuple('timing', [('mean', float)])

from einops import rearrange, repeat
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
# from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
from triton.testing import do_bench

try:
    from triton_fused_attention import attention as triton_attention
except ImportError:
    triton_attention = None
triton_attention = None  # force-disable the Triton path even when the import succeeds
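

# Time the forward call of `func` with triton.testing.do_bench. do_bench reports milliseconds,
# so the result is scaled by 1e-3 to give mean runtime in seconds; the commented-out code inside
# the function shows earlier timing approaches (flash_attn's benchmark_forward and CUDA-graph replay).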
def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
    # # Warmup
    # for _ in range(5):
    #     func(*args, **kwargs)
    # time.sleep(1)
    # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1]
    # s = torch.cuda.Stream()
    # s.wait_stream(torch.cuda.current_stream())
    # with torch.cuda.stream(s):
    #     for _ in range(2):
    #         out = func(*args, **kwargs)
    # torch.cuda.current_stream().wait_stream(s)
    # graph = torch.cuda.CUDAGraph()
    # with torch.cuda.graph(graph):
    #     out = func(*args, **kwargs)
    # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
    # # return time_f[1].mean
    # return time_f[1]
    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
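

# FLOP model for attention: two matmuls (Q @ K^T and P @ V), each ~seqlen_q * avg_seqlen * headdim
# multiply-adds per head, where avg_seqlen accounts for causal masking and sliding windows.
# With 2 FLOPs per multiply-add, that gives the two factors of 2 in the return value.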
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)):
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (-1, -1):
            avg_seqlen = seqlen_k
        else:
            row_idx = torch.arange(seqlen_q, device='cuda')
            col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
            # window_size = (left, right): query i attends keys in
            # [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right]
            col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1))
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2
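

# Map torch dtypes to the corresponding cuDNN frontend data_type enum used when building the graphs below.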
def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
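

# Build a cuDNN frontend graph for the SDPA forward pass once (q, k, v in (batch, nheads, seqlen, headdim)
# layout) and return a closure that just executes it, so the benchmark measures graph execution only,
# not graph construction.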
def cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu = torch.empty_like(q_gpu)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o, stats = graph.sdpa(
        name="sdpa",
        q=q,
        k=k,
        v=v,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        # use_causal_mask_bottom_right=causal or window_size_left >= 0,
        use_causal_mask=causal or window_size_left >= 0,
        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,
    )
    o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        stats: stats_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu

    return run
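

# Same idea for the backward pass: build the cuDNN sdpa_backward graph up front and return a closure
# that writes dq, dk, dv into preallocated buffers.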
def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=-1):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert g.shape == (b, nheads, seqlen_q, headdim)
    assert o.shape == (b, nheads, seqlen_q, headdim)
    assert lse.shape == (b, nheads, seqlen_q, 1)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g
    dq_gpu = torch.empty_like(q_gpu)
    dk_gpu = torch.empty_like(k_gpu)
    dv_gpu = torch.empty_like(v_gpu)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o = graph.tensor_like(o_gpu.detach())
    g = graph.tensor_like(g_gpu.detach())
    stats = graph.tensor_like(lse.detach())
    dq, dk, dv = graph.sdpa_backward(
        name="sdpa_backward",
        q=q,
        k=k,
        v=v,
        o=o,
        dO=g,
        stats=stats,
        attn_scale=1.0 / math.sqrt(headdim),
        # use_causal_mask_bottom_right=causal or window_size_left >= 0,
        use_causal_mask=causal or window_size_left >= 0,
        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,
    )
    dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
    dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
    dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        g: g_gpu,
        stats: lse,
        dq: dq_gpu,
        dk: dk_gpu,
        dv: dv_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return dq_gpu, dk_gpu, dv_gpu

    return run
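

# Benchmark configuration. Edit these globals (dtype, seqlen, headdim, varlen, page_size, ...) to change
# what gets measured; the commented-out lines are alternative presets.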
torch.manual_seed(0)

repeats = 10
dropout_p = 0.0
causal = False
dtype = torch.bfloat16
# dtype = torch.float8_e4m3fn
dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
device = 'cuda'
verbose = True
varlen = False
page_size = None
softcap = 0.0
V_colmajor = False
deterministic = False
batch_size = 2
# seqlen = 2048
seqlen = 8192
# seqlen = 4096
# seqlen = 2047
dim = 2048
# headdim = 128
# headdim = 64
headdim = 256
# for headdim in [64, 128, 256]:
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# bs_seqlen_vals = [(32, 512), (16, 1024)]
# bs_seqlen_vals = [(2, 64 * 132)]
bs_seqlen_vals = [(2, 8192)]
# bs_seqlen_vals = [(1, 16 * 1024)]
time_f = {}
time_b = {}

# tflops_matmul = {}
# m, n = 8192, 8192
# for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
#     a = torch.randn(m, k, device=device, dtype=dtype)
#     b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
#     nFLOPS_matmul = 2 * m * n * k
#     m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS')
#     print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
#     tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12
# # import pickle
# # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
# # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp:
# #     pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL)
# exit(0)

# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192]:
# for headdim in [64, 96, 128, 192, 256]:
# for headdim in [64, 96, 128]:
# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192, 256]:
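# Main sweep: for each (headdim, batch_size, seqlen) combination and for causal in {False, True},
# time FlashAttention-2, Triton, cuDNN, and FlashAttention-3 forward and backward passes and report
# runtime in ms and achieved TFLOPS.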
for headdim in [128]:
    nheads = dim // headdim
    # headdim = 64
    # batch_size = 64
    # seqlen = 512
    # nheads = 8
    # headdim = 128
    nheads_kv = nheads
    # nheads_kv = nheads // 4

    for batch_size, seqlen in bs_seqlen_vals:
        num_splits = 1
        window_size = (-1, -1)
        # window_size = (seqlen // 2 - 1, 0)
        sink_token_length = 0
        pack_gqa = None
        # seqlen_q = 64
        seqlen_q = seqlen
        leftpad_k = None
        # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)
        q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
        v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
        v_fa3 = v if not V_colmajor else v_colmajor
        # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)
        # a and b are only used by the (commented-out) cuBLAS bmm comparison below
        a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen)
        b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2)
        # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype)
        # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2)
        if varlen:
            q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]]
            cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
            cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen
            # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32)
            # q_unpad = q_unpad[:256]
            # seqlen_q = 256
            # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32)
            # q_unpad = q_unpad[:384]
            # seqlen_q = 384
        if page_size is not None:
            assert seqlen % page_size == 0
            k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]]
            page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),
                                   "(b s) -> b s", s=seqlen // page_size)
        else:
            page_table = None

        for causal in [False, True]:
        # for causal in [False]:
            print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size)
            if cudnn is not None:
            # if False:
                if headdim <= 256 and dtype != torch.float8_e4m3fn:
                    cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
                    cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
            # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
            if dtype != torch.float8_e4m3fn:
            # if False:
                if not varlen:
                    m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
                else:
                    m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
                time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean
                time.sleep(1)
                if not varlen:
                    _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=verbose, desc='Fav2')
                else:
                    _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=verbose, desc='Fav2')
                time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)

            if headdim <= 256 and dtype != torch.float8_e4m3fn:
                if triton_attention is not None:
                    qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]]
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    m3 = time_fwd(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                    time_f[(causal, headdim, batch_size, seqlen), "Triton"] = m3.mean
                    # if causal:  # triton bwd only works w causal for now
                    #     time.sleep(1)
                    #     _, m3b = benchmark_backward(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                    #     time_b[(causal, headdim, batch_size, seqlen), "Triton"] = m3b.mean
                    #     # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True)

            if cudnn is not None:
            # if False:
                if headdim <= 256 and dtype != torch.float8_e4m3fn:
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
                    time.sleep(1)
                    m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean
                    # pytorch_profiler(cudnn_spda, backward=False)
                    # pytorch_profiler(cudnn_spda_bwd, backward=False)

            time.sleep(1)
            if not varlen:
                # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad=leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
            else:
                m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
            time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
            # time.sleep(1)
            # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False)
            # nFLOPS_matmul = nFLOPS
            # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1]
            # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS')
            if dtype != torch.float8_e4m3fn:
                time.sleep(1)
                if not varlen:
                    _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=verbose, desc='Fav3')
                else:
                    _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=verbose, desc='Fav3')
                time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
                # time.sleep(1)
                # if not varlen:
                #     pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
                # else:
                #     pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
            # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')

            # Backward TFLOPS use a 2.5x factor: the backward pass does ~5 matmuls vs 2 in the forward
            if dtype != torch.float8_e4m3fn:
            # if False:
                print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
                print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
            if headdim <= 256 and dtype != torch.float8_e4m3fn:
                if triton_attention is not None:
                    print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS')
                    # if causal:
                    #     print(f'Triton bwd: {m3b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m3b.mean * 1e-12):.1f} TFLOPS')
                if cudnn is not None:
                    print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
                    print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
            print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
            if dtype != torch.float8_e4m3fn:
                print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
            # benchmark_forward(torch.square, k)
            # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')

# print(time_f)
# print(time_b)
# import pickle
# # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp:
# with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp:
#     pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)