# benchmark_attn.py
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

try:
    import cudnn
except ImportError:
    cudnn = None

from einops import rearrange, repeat
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
# from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

try:
    from triton_fused_attention import attention as triton_attention
except ImportError:
    triton_attention = None
triton_attention = None  # Triton path disabled: this overrides the import above
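
# Timing helper: warm the op up on a side stream, capture one call in a CUDA graph,
# then benchmark `graph.replay()`. Replaying the captured graph removes per-call CPU
# launch overhead from the measurement, so the number reflects GPU time. Returns the
# Measurement object from benchmark_forward (use `.mean` for seconds per call).
# Example usage (sketch, not from the original script):
#   m = time_fwd(torch.matmul, a, b, desc='matmul'); print(f'{m.mean * 1e3:.3f} ms')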
def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
    # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1]
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            out = func(*args, **kwargs)
    torch.cuda.current_stream().wait_stream(s)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = func(*args, **kwargs)
    time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
    # return time_f[1].mean
    return time_f[1]
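
# FLOP accounting: attention does two GEMMs per head (Q @ K^T and P @ V), each
# seqlen_q * seqlen_k * headdim multiply-accumulates counted as 2 FLOPs, hence
#   nFLOPS = 4 * batch * nheads * seqlen_q * avg_seqlen * headdim
# where avg_seqlen is the average number of keys actually attended to (roughly halved
# by causal masking, reduced further by a sliding window). Worked example with the
# defaults below (batch=2, nheads=16, seqlen=8192, headdim=128):
# 4 * 2 * 16 * 8192 * 8192 * 128 ≈ 1.1e12 FLOPs non-causal, about half with causal=True.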
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)):
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (-1, -1):
            avg_seqlen = seqlen_k
        else:
            row_idx = torch.arange(seqlen_q, device='cuda')
            col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
            col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1))
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2
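
# Map torch dtypes to the cuDNN frontend data_type enum used when declaring graph tensors.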
def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
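
# Build a cuDNN frontend SDPA (forward) graph once and return a closure that executes it.
# Inputs are expected in (batch, nheads, seqlen, headdim) layout, i.e. already transposed
# from the (batch, seqlen, nheads, headdim) tensors used elsewhere in this script. The
# closure reuses the same input/output buffers and workspace on every call, so timing it
# measures pure execution time with no graph-building overhead.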
def cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu = torch.empty_like(q_gpu)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o, stats = graph.sdpa(
        name="sdpa",
        q=q,
        k=k,
        v=v,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        # use_causal_mask_bottom_right=causal or window_size_left >= 0,
        use_causal_mask=causal or window_size_left >= 0,
        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,
    )
    o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        stats: stats_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu

    return run
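
# Same pattern for the backward pass: build a cuDNN sdpa_backward graph that takes q, k, v,
# the forward output o, the incoming gradient dO, and the softmax stats (logsumexp), and
# return a closure that writes dq, dk, dv into preallocated buffers.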
def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=-1):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert g.shape == (b, nheads, seqlen_q, headdim)
    assert o.shape == (b, nheads, seqlen_q, headdim)
    assert lse.shape == (b, nheads, seqlen_q, 1)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g
    dq_gpu = torch.empty_like(q_gpu)
    dk_gpu = torch.empty_like(k_gpu)
    dv_gpu = torch.empty_like(v_gpu)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o = graph.tensor_like(o_gpu.detach())
    g = graph.tensor_like(g_gpu.detach())
    stats = graph.tensor_like(lse.detach())
    dq, dk, dv = graph.sdpa_backward(
        name="sdpa_backward",
        q=q,
        k=k,
        v=v,
        o=o,
        dO=g,
        stats=stats,
        attn_scale=1.0 / math.sqrt(headdim),
        # use_causal_mask_bottom_right=causal or window_size_left >= 0,
        use_causal_mask=causal or window_size_left >= 0,
        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,
    )
    dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
    dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
    dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        g: g_gpu,
        stats: lse,
        dq: dq_gpu,
        dk: dk_gpu,
        dv: dv_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return dq_gpu, dk_gpu, dv_gpu

    return run
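
# Benchmark configuration. Key knobs: dtype (bf16 by default; fp8 paths are guarded by
# dtype checks below), varlen (packed variable-length sequences), page_size (paged KV
# cache), the GQA ratio (nheads_kv vs nheads), and the (batch_size, seqlen) pairs in
# bs_seqlen_vals. The timed sections below sleep briefly between implementations to
# limit power-throttling carryover from the previous benchmark.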
torch.manual_seed(0)

repeats = 10
dropout_p = 0.0
causal = False
dtype = torch.bfloat16
# dtype = torch.float8_e4m3fn
dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
device = 'cuda'
verbose = True
varlen = False
page_size = None
softcap = 0.0
V_colmajor = False
deterministic = False
batch_size = 2
# seqlen = 2048
seqlen = 8192
# seqlen = 4096
# seqlen = 2047
dim = 2048
# headdim = 128
# headdim = 64
headdim = 256
# for headdim in [64, 128, 256]:
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
# bs_seqlen_vals = [(32, 512), (16, 1024)]
# bs_seqlen_vals = [(2, 64 * 132)]
# bs_seqlen_vals = [(2 * 8, 8192)]
bs_seqlen_vals = [(2, 8192)]
# bs_seqlen_vals = [(1, 16 * 1024)]

time_f = {}
time_b = {}

# tflops_matmul = {}
# m, n = 8192, 8192
# for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
#     a = torch.randn(m, k, device=device, dtype=dtype)
#     b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
#     nFLOPS_matmul = 2 * m * n * k
#     m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS')
#     print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')
#     tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12
# # import pickle
# # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
# # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp:
# #     pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL)
# exit(0)

# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192]:
# for headdim in [64, 96, 128, 192, 256]:
# for headdim in [64, 96, 128]:
# for headdim in [64, 128, 256]:
for headdim in [128]:
    nheads = dim // headdim
    # headdim = 64
    # batch_size = 64
    # seqlen = 512
    # nheads = 8
    # headdim = 128
    # nheads_kv = nheads
    nheads_kv = nheads // 4

    for batch_size, seqlen in bs_seqlen_vals:
        num_splits = 1
        window_size = (-1, -1)
        # window_size = (seqlen // 2 - 1, 0)
        sink_token_length = 0
        pack_gqa = None
        # seqlen_q = 64
        seqlen_q = seqlen
        leftpad_k = None
        # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)
        q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
        v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
        v_fa3 = v if not V_colmajor else v_colmajor
        # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)
        g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)
        a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen)
        b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2)
        # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype)
        # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2)
        if varlen:
            q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]]
            cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
            cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen
            # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32)
            # q_unpad = q_unpad[:256]
            # seqlen_q = 256
            # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32)
            # q_unpad = q_unpad[:384]
            # seqlen_q = 384
        if page_size is not None:
            assert seqlen % page_size == 0
            k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]]
            page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),
                                   "(b s) -> b s", s=seqlen // page_size)
        else:
            page_table = None
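
        # For each causal setting: time the cuDNN forward/backward graphs and the
        # FlashAttention-3 forward pass (the FlashAttention-2 and Triton paths are
        # currently disabled in this script), convert mean runtimes to TFLOPS via
        # flops(), and record them in time_f / time_b keyed by
        # ((causal, headdim, batch_size, seqlen), implementation).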
        for causal in [False, True]:
        # for causal in [False]:
            print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size)
            if cudnn is not None:
            # if False:
                if headdim <= 256 and dtype != torch.float8_e4m3fn:
                    cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
                    cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
            # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
            # if dtype != torch.float8_e4m3fn:
            if False:
                if not varlen:
                    m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, repeats=repeats, verbose=verbose, desc='Fav2')
                else:
                    m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, repeats=repeats, verbose=verbose, desc='Fav2')
                time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean
                time.sleep(1)
                _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, deterministic=deterministic,
                                            repeats=repeats, verbose=verbose, desc='Fav2')
                time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)

            if headdim <= 256 and dtype != torch.float8_e4m3fn:
                if triton_attention is not None:
                    qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]]
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    _, m3 = benchmark_forward(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                    time_f[(causal, headdim, batch_size, seqlen), "Triton"] = m3.mean
                    # if causal: # triton bwd only works w causal for now
                    #     time.sleep(1)
                    #     _, m3b = benchmark_backward(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                    #     time_b[(causal, headdim, batch_size, seqlen), "Triton"] = m3b.mean
                    # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True)

            if cudnn is not None:
            # if False:
                if headdim <= 256 and dtype != torch.float8_e4m3fn:
                    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                    _, m2 = benchmark_forward(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
                    # m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
                    time.sleep(1)
                    _, m2b = benchmark_forward(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean
                    # pytorch_profiler(cudnn_spda, backward=False)
                    # pytorch_profiler(cudnn_spda_bwd, backward=False)
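
            # FlashAttention-3 forward timing. flash_attn_with_kvcache is imported above
            # as flash_attn_func_v3; the extra arguments (page_table, cache_leftpad,
            # sink_token_length, num_splits, pack_gqa, softcap) are passed through from
            # the configuration above and are all at their inert defaults in this script
            # (no paging, no left-padding, no attention sink, no softcapping).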
            time.sleep(1)
            if not varlen:
                m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad=leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
            else:
                m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
                # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
            time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean

            # # time.sleep(1)
            # # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False)
            # nFLOPS_matmul = nFLOPS
            # # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1]
            # # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS')
            # if dtype != torch.float8_e4m3fn:
            #     time.sleep(1)
            #     if not varlen:
            #         _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
            #                                     repeats=repeats, verbose=verbose, desc='Fav3')
            #     else:
            #         _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
            #                                     repeats=repeats, verbose=verbose, desc='Fav3')
            #     time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
            #     # time.sleep(1)
            #     # if not varlen:
            #     #     pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
            #     # else:
            #     #     pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
            # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')

            # if dtype != torch.float8_e4m3fn:
            if False:
                print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
                print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
            if headdim <= 256 and dtype != torch.float8_e4m3fn:
                if triton_attention is not None:
                    print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS')
                    # if causal:
                    #     print(f'Triton bwd: {m3b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m3b.mean * 1e-12):.1f} TFLOPS')
                if cudnn is not None:
                    print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
                    print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
            print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
            # if dtype != torch.float8_e4m3fn:
            #     print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
            # benchmark_forward(torch.square, k)
            # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')

# print(time_f)
# print(time_b)

# import pickle
# # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_hdim{headdim}.plk', 'wb') as fp:
# with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp:
# # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp:
#     pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)