benchmark_attn.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. from functools import partial
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import time
  7. try:
  8. import cudnn
  9. except ImportError:
  10. cudnn = None
  11. from einops import rearrange, repeat
  12. # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
  13. from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
  14. from flash_attn.flash_attn_interface import flash_attn_func
  15. from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
  16. # Need to install triton nightly:
  17. # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
  18. try:
  19. from triton_fused_attention import attention as triton_attention
  20. except ImportError:
  21. triton_attention = None
  22. def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
  23. assert mode in ["fwd", "bwd", "fwd_bwd"]
  24. f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
  25. return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
  26. def convert_to_cudnn_type(torch_type):
  27. if torch_type == torch.float16:
  28. return cudnn.data_type.HALF
  29. elif torch_type == torch.bfloat16:
  30. return cudnn.data_type.BFLOAT16
  31. elif torch_type == torch.float32:
  32. return cudnn.data_type.FLOAT
  33. elif torch_type == torch.int32:
  34. return cudnn.data_type.INT32
  35. elif torch_type == torch.int64:
  36. return cudnn.data_type.INT64
  37. else:
  38. raise ValueError("Unsupported tensor data type.")
  39. def cudnn_sdpa_setup(q, k, v, grad, causal=False):
  40. b, nheads, seqlen_q, headdim = q.shape
  41. _, _, seqlen_k, _ = k.shape
  42. assert v.shape == (b, nheads, seqlen_k, headdim)
  43. assert cudnn is not None, 'CUDNN is not available'
  44. q_gpu, k_gpu, v_gpu = q, k, v
  45. o_gpu = torch.empty_like(q_gpu)
  46. stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
  47. graph_forward = cudnn.pygraph(
  48. io_data_type=convert_to_cudnn_type(q.dtype),
  49. intermediate_data_type=cudnn.data_type.FLOAT,
  50. compute_data_type=cudnn.data_type.FLOAT,
  51. )
  52. q_forward = graph_forward.tensor_like(q_gpu.detach())
  53. k_forward = graph_forward.tensor_like(k_gpu.detach())
  54. v_forward = graph_forward.tensor_like(v_gpu.detach())
  55. o_forward, stats_forward = graph_forward.sdpa(
  56. name="sdpa",
  57. q=q_forward,
  58. k=k_forward,
  59. v=v_forward,
  60. is_inference=False,
  61. attn_scale=1.0 / math.sqrt(headdim),
  62. use_causal_mask=causal,
  63. )
  64. o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
  65. stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
  66. graph_forward.validate()
  67. graph_forward.build_operation_graph()
  68. graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
  69. graph_forward.check_support()
  70. graph_forward.build_plans()
  71. variant_pack_forward = {
  72. q_forward: q_gpu,
  73. k_forward: k_gpu,
  74. v_forward: v_gpu,
  75. o_forward: o_gpu,
  76. stats_forward: stats_gpu,
  77. }
  78. dQ_gpu = torch.empty_like(q_gpu)
  79. dK_gpu = torch.empty_like(k_gpu)
  80. dV_gpu = torch.empty_like(v_gpu)
  81. dO_gpu = grad
  82. graph_backward = cudnn.pygraph(
  83. io_data_type=cudnn.data_type.HALF,
  84. intermediate_data_type=cudnn.data_type.FLOAT,
  85. compute_data_type=cudnn.data_type.FLOAT,
  86. )
  87. q_backward = graph_backward.tensor_like(q_gpu.detach())
  88. k_backward = graph_backward.tensor_like(k_gpu.detach())
  89. v_backward = graph_backward.tensor_like(v_gpu.detach())
  90. o_backward = graph_backward.tensor_like(o_gpu.detach())
  91. dO_backward = graph_backward.tensor_like(dO_gpu.detach())
  92. stats_backward = graph_backward.tensor_like(stats_gpu.detach())
  93. dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
  94. name="sdpa_backward",
  95. q=q_backward,
  96. k=k_backward,
  97. v=v_backward,
  98. o=o_backward,
  99. dO=dO_backward,
  100. stats=stats_backward,
  101. attn_scale=1.0 / math.sqrt(headdim),
  102. use_causal_mask=causal,
  103. )
  104. dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
  105. dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
  106. dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())
  107. graph_backward.validate()
  108. graph_backward.build_operation_graph()
  109. graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
  110. graph_backward.check_support()
  111. graph_backward.build_plans()
  112. variant_pack_backward = {
  113. q_backward: q_gpu,
  114. k_backward: k_gpu,
  115. v_backward: v_gpu,
  116. o_backward: o_gpu,
  117. dO_backward: dO_gpu,
  118. stats_backward: stats_gpu,
  119. dQ_backward: dQ_gpu,
  120. dK_backward: dK_gpu,
  121. dV_backward: dV_gpu,
  122. }
  123. workspace = torch.empty(
  124. max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()),
  125. device="cuda", dtype=torch.uint8
  126. )
  127. def run_fwd(*args, **kwargs):
  128. graph_forward.execute(variant_pack_forward, workspace)
  129. return o_gpu, stats_gpu
  130. def run_bwd(*args, **kwargs):
  131. graph_backward.execute(variant_pack_backward, workspace)
  132. return dQ_gpu, dK_gpu, dV_gpu
  133. return run_fwd, run_bwd
  134. torch.manual_seed(0)
  135. repeats = 100
  136. dropout_p = 0.0
  137. causal = False
  138. dtype = torch.float16
  139. device = 'cuda'
  140. verbose = False
  141. batch_size = 2
  142. # seqlen = 2048
  143. seqlen = 8192
  144. # seqlen = 4096
  145. # seqlen = 2047
  146. dim = 2048
  147. # headdim = 128
  148. # headdim = 64
  149. headdim = 256
  150. # for mode in ['fwd', 'bwd']:
  151. for mode in ['fwd']:
  152. for headdim in [64, 128, 256]:
  153. # for headdim in [128]:
  154. for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
  155. # for seqlen in [8192]:
  156. nheads = dim // headdim
  157. # nheads = 24
  158. # headdim = 64
  159. # batch_size = 64
  160. # seqlen = 512
  161. # nheads = 8
  162. # headdim = 128
  163. nheads_kv = nheads
  164. qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
  165. requires_grad=True)
  166. q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
  167. k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
  168. v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
  169. q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
  170. k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
  171. v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
  172. grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
  173. grad_t = grad.transpose(1, 2).contiguous()
  174. bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
  175. for causal in [False, True]:
  176. # for causal in [True]:
  177. print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
  178. if headdim <= 128 and cudnn is not None:
  179. cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
  180. f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
  181. _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
  182. if mode == 'bwd':
  183. ref_dv, v.grad = v.grad.clone(), None
  184. ref_dk, k.grad = k.grad.clone(), None
  185. ref_dq, q.grad = q.grad.clone(), None
  186. # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
  187. if headdim <= 128:
  188. if triton_attention is not None:
  189. if mode == 'fwd':
  190. time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
  191. _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
  192. # TODO: fix Triton numeric errors.
  193. # if mode == 'bwd':
  194. # dv, v_t.grad = v_t.grad.clone(), None
  195. # dk, k_t.grad = k_t.grad.clone(), None
  196. # dq, q_t.grad = q_t.grad.clone(), None
  197. # torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
  198. # torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
  199. # torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
  200. if cudnn is not None:
  201. time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
  202. if mode == 'fwd':
  203. _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
  204. else:
  205. cudnn_sdpa_fwd()
  206. _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
  207. dq, dk, dv = cudnn_sdpa_bwd()
  208. torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
  209. torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
  210. torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
  211. # pytorch_profiler(cudnn_sdpa, backward=False)
  212. if headdim == 128 or mode == 'fwd':
  213. time.sleep(1)
  214. _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
  215. q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
  216. k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
  217. v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
  218. lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
  219. cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
  220. time.sleep(1)
  221. _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
  222. if mode == 'bwd':
  223. dv, v.grad = v.grad.clone(), None
  224. dk, k.grad = k.grad.clone(), None
  225. dq, q.grad = q.grad.clone(), None
  226. torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
  227. torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
  228. torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
  229. # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
  230. print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
  231. if headdim <= 128:
  232. if triton_attention is not None:
  233. print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
  234. if cudnn is not None:
  235. print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
  236. if headdim == 128 or mode == 'fwd':
  237. print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
  238. print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')