# benchmark_attn.py
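# Benchmarks attention forward and backward passes across several implementations:
# FlashAttention-2 (flash_attn_func), FlashAttention-3 (flash_attn_func_v3, including its
# varlen path), cuDNN's SDPA graph API (when the `cudnn` frontend is importable), and a
# Triton fused-attention kernel (when available). Reports mean runtime and TFLOPS.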

from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

try:
    import cudnn
except ImportError:
    cudnn = None

from einops import rearrange, repeat

# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3

# Need to install triton nightly:
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
try:
    from triton_fused_attention import attention as triton_attention
except ImportError:
    triton_attention = None
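# Attention FLOP convention: two matmuls (Q @ K^T and P @ V), each 2 * seqlen_q * seqlen_k
# * headdim FLOPs per head, i.e. 4 * batch * nheads * headdim * seqlen_q * seqlen_k for the
# forward pass. Causal masking halves it; the backward pass is counted as 2.5x the forward.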
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * nheads * headdim * seqlen_q * seqlen_k // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
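# Build cuDNN frontend graphs for the SDPA forward and backward passes once, and return two
# closures (run_fwd, run_bwd) that only execute the prebuilt graphs, so graph construction
# stays outside the timed region.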
def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_kv, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu, stats_gpu = o, stats

    graph_forward = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_forward = graph_forward.tensor_like(q_gpu.detach())
    k_forward = graph_forward.tensor_like(k_gpu.detach())
    v_forward = graph_forward.tensor_like(v_gpu.detach())

    seqlens_reshaped = seqlens if varlen else None
    seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
    seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None

    o_forward, stats_forward = graph_forward.sdpa(
        name="sdpa",
        q=q_forward,
        k=k_forward,
        v=v_forward,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        use_padding_mask=varlen,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
    )

    o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph_forward.validate()
    graph_forward.build_operation_graph()
    graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_forward.check_support()
    graph_forward.build_plans()

    variant_pack_forward = {
        q_forward: q_gpu,
        k_forward: k_gpu,
        v_forward: v_gpu,
        o_forward: o_gpu,
        stats_forward: stats_gpu,
        seq_len_q: seqlens_reshaped,
        seq_len_kv: seqlens_reshaped,
    }
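    # Backward graph: consumes q, k, v, the forward output o, the incoming gradient dO, and
    # the saved softmax stats, and writes dQ, dK, dV into preallocated buffers.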
    dQ_gpu = torch.empty_like(q_gpu)
    dK_gpu = torch.empty_like(k_gpu)
    dV_gpu = torch.empty_like(v_gpu)
    dO_gpu = grad

    graph_backward = cudnn.pygraph(
        io_data_type=cudnn.data_type.HALF,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q_backward = graph_backward.tensor_like(q_gpu.detach())
    k_backward = graph_backward.tensor_like(k_gpu.detach())
    v_backward = graph_backward.tensor_like(v_gpu.detach())
    o_backward = graph_backward.tensor_like(o_gpu.detach())
    dO_backward = graph_backward.tensor_like(dO_gpu.detach())
    stats_backward = graph_backward.tensor_like(stats_gpu.detach())
    seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None
    seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None

    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
        name="sdpa_backward",
        q=q_backward,
        k=k_backward,
        v=v_backward,
        o=o_backward,
        dO=dO_backward,
        stats=stats_backward,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        use_padding_mask=varlen,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
    )

    dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
    dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
    dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())

    graph_backward.validate()
    graph_backward.build_operation_graph()
    graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph_backward.check_support()
    graph_backward.build_plans()

    variant_pack_backward = {
        q_backward: q_gpu,
        k_backward: k_gpu,
        v_backward: v_gpu,
        o_backward: o_gpu,
        dO_backward: dO_gpu,
        stats_backward: stats_gpu,
        dQ_backward: dQ_gpu,
        dK_backward: dK_gpu,
        dV_backward: dV_gpu,
        seq_len_q: seqlens_reshaped,
        seq_len_kv: seqlens_reshaped,
    }
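    # Single workspace sized for whichever graph needs more scratch space, shared by both
    # closures below.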
    workspace = torch.empty(
        max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
        device="cuda", dtype=torch.uint8
    )

    def run_fwd(*args, **kwargs):
        graph_forward.execute(variant_pack_forward, workspace)
        return o_gpu, stats_gpu

    def run_bwd(*args, **kwargs):
        graph_backward.execute(variant_pack_backward, workspace)
        return dQ_gpu, dK_gpu, dV_gpu

    return run_fwd, run_bwd
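# Benchmark configuration: fp16, batch size 2, model dimension 2048, 100 timed repeats per
# kernel. The headdim and seqlen set here are overwritten by the sweep below.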
torch.manual_seed(0)
repeats = 100
dropout_p = 0.0
causal = False
dtype = torch.float16
device = 'cuda'
verbose = False
batch_size = 2
# seqlen = 2048
seqlen = 8192
# seqlen = 4096
# seqlen = 2047
dim = 2048
# headdim = 128
# headdim = 64
headdim = 256
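# Sweep over mode (fwd/bwd), head dimension, sequence length, and causal masking.
# FlashAttention-2 runs first in each configuration and serves as the numerical reference
# (atol=rtol=0.05) for the other implementations.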
for mode in ['fwd', 'bwd']:
# for mode in ['bwd']:
    for headdim in [64, 128, 256]:
    # for headdim in [128]:
        for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
        # for seqlen in [8192]:
            nheads = dim // headdim
            # nheads = 24
            # headdim = 64
            # batch_size = 64
            # seqlen = 512
            # nheads = 8
            # headdim = 128
            # nheads = 16
            # headdim = 128
            nheads_kv = nheads
            # nheads_kv = 1
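            # flash-attn kernels take (batch, seqlen, nheads, headdim); the *_t copies are
            # transposed to (batch, nheads, seqlen, headdim) for the Triton and cuDNN paths.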
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
            q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
            k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
            v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()
            grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
            grad_t = grad.transpose(1, 2).contiguous()
            o_t = torch.empty_like(q.transpose(1, 2))
            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)

            bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
            for causal in [False, True]:
            # for causal in [True]:
                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                # For var-seq-len
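                # cuDNN's padding mask takes per-batch lengths as a (batch, 1, 1, 1) int32
                # tensor, while FA3's varlen API takes cumulative lengths (cu_seqlens).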
                lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
                cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
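                # The cuDNN graphs are specialized to the current shapes and causal setting,
                # so they are rebuilt for every configuration.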
                if headdim <= 128 and cudnn is not None:
                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                if mode == 'bwd':
                    ref_dv, v.grad = v.grad.clone(), None
                    ref_dk, k.grad = k.grad.clone(), None
                    ref_dq, q.grad = q.grad.clone(), None
                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
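                # Triton and cuDNN are only benchmarked for headdim <= 128; the Triton path
                # is forward-only here (its backward check is disabled, see TODO below).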
                if headdim <= 128:
                    if triton_attention is not None and nheads_kv == nheads:
                        if mode == 'fwd':
                            time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                            _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                        # TODO: fix Triton numeric errors.
                        # if mode == 'bwd':
                        #     dv, v_t.grad = v_t.grad.clone(), None
                        #     dk, k_t.grad = k_t.grad.clone(), None
                        #     dq, q_t.grad = q_t.grad.clone(), None
                        #     torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                        #     torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                        #     torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                    if cudnn is not None:
                        time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                        if mode == 'fwd':
                            _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                            _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                            cudnn_sdpa_fwd()
                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                            cudnn_sdpa_fwd_varlen()
                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                        else:
                            cudnn_sdpa_fwd()
                            _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                            _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                            dq, dk, dv = cudnn_sdpa_bwd()
                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                            dq, dk, dv = cudnn_sdpa_bwd_varlen()
                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                        # pytorch_profiler(cudnn_sdpa, backward=False)
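                # FlashAttention-3: dense and varlen (cu_seqlens) paths; the backward pass is
                # only benchmarked for headdim <= 128.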
                if headdim <= 128 or mode == 'fwd':
                    time.sleep(1)
                    _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                    q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                    k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                    v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                    time.sleep(1)
                    if mode == 'bwd':
                        dv, v.grad = v.grad.clone(), None
                        dk, k.grad = k.grad.clone(), None
                        dq, q.grad = q.grad.clone(), None
                        torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                        torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)

                    bench_var_fn = bench_fn
                    if mode == 'bwd':
                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                    # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
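                # Report mean runtime and achieved TFLOPS for whichever implementations ran.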
                print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                if headdim <= 128:
                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
                        print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                    if cudnn is not None:
                        print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                        print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
                if headdim <= 128 or mode == 'fwd':
                    print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                    print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')