benchmark_flash_attention.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Install the newest triton version with
  2. # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
  3. import pickle
  4. import math
  5. import time
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from einops import rearrange, repeat
  10. from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
  11. from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
  12. from flash_attn import flash_attn_qkvpacked_func
  13. from flash_attn_interface import flash_attn_func
  14. try:
  15. from triton.ops.flash_attention import attention as attention_triton
  16. except ImportError:
  17. attention_triton = None
  18. try:
  19. import xformers.ops as xops
  20. except ImportError:
  21. xops = None
  22. try:
  23. import cudnn
  24. except ImportError:
  25. cudnn = None
  26. def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
  27. assert mode in ["fwd", "bwd", "fwd_bwd"]
  28. f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
  29. return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
  30. def efficiency(flop, time):
  31. return (flop / time / 10**12) if not math.isnan(time) else 0.0
  32. def convert_to_cudnn_type(torch_type):
  33. if torch_type == torch.float16:
  34. return cudnn.data_type.HALF
  35. elif torch_type == torch.bfloat16:
  36. return cudnn.data_type.BFLOAT16
  37. elif torch_type == torch.float32:
  38. return cudnn.data_type.FLOAT
  39. elif torch_type == torch.int32:
  40. return cudnn.data_type.INT32
  41. elif torch_type == torch.int64:
  42. return cudnn.data_type.INT64
  43. else:
  44. raise ValueError("Unsupported tensor data type.")
  45. def cudnn_spda_setup(q, k, v, causal=False):
  46. b, nheads, seqlen_q, headdim = q.shape
  47. _, _, seqlen_k, _ = k.shape
  48. assert v.shape == (b, nheads, seqlen_k, headdim)
  49. assert cudnn is not None, 'CUDNN is not available'
  50. q_gpu, k_gpu, v_gpu = q, k, v
  51. o_gpu = torch.empty_like(q_gpu)
  52. stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
  53. graph = cudnn.pygraph(
  54. io_data_type=convert_to_cudnn_type(q.dtype),
  55. intermediate_data_type=cudnn.data_type.FLOAT,
  56. compute_data_type=cudnn.data_type.FLOAT,
  57. )
  58. q = graph.tensor_like(q_gpu.detach())
  59. k = graph.tensor_like(k_gpu.detach())
  60. v = graph.tensor_like(v_gpu.detach())
  61. o, stats = graph.sdpa(
  62. name="sdpa",
  63. q=q,
  64. k=k,
  65. v=v,
  66. is_inference=False,
  67. attn_scale=1.0 / math.sqrt(headdim),
  68. use_causal_mask=causal,
  69. )
  70. o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
  71. stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
  72. graph.validate()
  73. graph.build_operation_graph()
  74. graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
  75. graph.check_support()
  76. graph.build_plans()
  77. variant_pack = {
  78. q: q_gpu,
  79. k: k_gpu,
  80. v: v_gpu,
  81. o: o_gpu,
  82. stats: stats_gpu,
  83. }
  84. workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
  85. def run(*args, **kwargs):
  86. graph.execute(variant_pack, workspace)
  87. return o_gpu
  88. return run
  89. def attention_pytorch(qkv, dropout_p=0.0, causal=True):
  90. """
  91. Arguments:
  92. qkv: (batch_size, seqlen, 3, nheads, head_dim)
  93. dropout_p: float
  94. Output:
  95. output: (batch_size, seqlen, nheads, head_dim)
  96. """
  97. batch_size, seqlen, _, nheads, d = qkv.shape
  98. q, k, v = qkv.unbind(dim=2)
  99. q = rearrange(q, 'b t h d -> (b h) t d')
  100. k = rearrange(k, 'b s h d -> (b h) d s')
  101. softmax_scale = 1.0 / math.sqrt(d)
  102. # Preallocate attn_weights for `baddbmm`
  103. scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
  104. scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
  105. '(b h) t s -> b h t s', h=nheads)
  106. if causal:
  107. # "triu_tril_cuda_template" not implemented for 'BFloat16'
  108. # So we have to construct the mask in float
  109. causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
  110. # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
  111. scores = scores + causal_mask.to(dtype=scores.dtype)
  112. attention = torch.softmax(scores, dim=-1)
  113. attention_drop = F.dropout(attention, dropout_p)
  114. output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
  115. return output.to(dtype=qkv.dtype)
  116. def time_fwd_bwd(func, *args, **kwargs):
  117. time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
  118. time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
  119. return time_f[1].mean, time_b[1].mean
  120. repeats = 30
  121. device = 'cuda'
  122. dtype = torch.float16
  123. # Ideally, seq-len should be divisible by 132 to avoid wave quantization.
  124. # However, the existing Triton implementation doesn't support seq-len like 8448.
  125. bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192)]
  126. # bs_seqlen_vals = [(2, 8192)]
  127. causal_vals = [False]
  128. # headdim_vals = [64, 128]
  129. headdim_vals = [128]
  130. dim = 128
  131. dropout_p = 0.0
  132. methods = (["Flash2", "Pytorch", "Flash3"]
  133. + (["Triton"] if attention_triton is not None else [])
  134. + (["xformers.c"] if xops is not None else [])
  135. + (["xformers.f"] if xops is not None else [])
  136. + (["cudnn"] if cudnn is not None else []))
  137. time_f = {}
  138. time_b = {}
  139. time_f_b = {}
  140. speed_f = {}
  141. speed_b = {}
  142. speed_f_b = {}
  143. for causal in causal_vals:
  144. for headdim in headdim_vals:
  145. for batch_size, seqlen in bs_seqlen_vals:
  146. config = (causal, headdim, batch_size, seqlen)
  147. nheads = dim // headdim
  148. qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
  149. requires_grad=True)
  150. f, b = time_fwd_bwd(
  151. flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
  152. )
  153. time_f[config, "Flash2"] = f
  154. time_b[config, "Flash2"] = b
  155. try:
  156. qkv = qkv.detach().requires_grad_(True)
  157. f, b = time_fwd_bwd(
  158. attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
  159. )
  160. res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
  161. except: # Skip if OOM
  162. f, b = float('nan'), float('nan')
  163. time_f[config, "Pytorch"] = f
  164. time_b[config, "Pytorch"] = b
  165. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
  166. requires_grad=True) for _ in range(3)]
  167. f, b = time_fwd_bwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
  168. res = flash_attn_func(q, k, v, causal=causal)
  169. time_f[config, "Flash3"] = f
  170. time_b[config, "Flash3"] = b
  171. if cudnn is not None:
  172. time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
  173. res = benchmark_forward(
  174. cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal),
  175. repeats=repeats, verbose=False
  176. )
  177. f = res[1].mean
  178. time_f[config, "cudnn"] = f
  179. time_b[config, "cudnn"] = math.inf
  180. if attention_triton is not None:
  181. q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
  182. requires_grad=True) for _ in range(3)]
  183. # Try both values of sequence_parallel and pick the faster one
  184. try:
  185. f, b = time_fwd_bwd(
  186. attention_triton, q, k, v, causal, headdim**(-0.5),
  187. False, repeats=repeats, verbose=False
  188. )
  189. except:
  190. f, b = float('nan'), float('inf')
  191. try:
  192. _, b0 = time_fwd_bwd(
  193. attention_triton, q, k, v, causal, headdim**(-0.5),
  194. True, repeats=repeats, verbose=False
  195. )
  196. except:
  197. b0 = float('inf')
  198. time_f[config, "Triton"] = f
  199. time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
  200. if xops is not None:
  201. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
  202. requires_grad=True) for _ in range(3)]
  203. f, b = time_fwd_bwd(
  204. xops.memory_efficient_attention, q, k, v,
  205. attn_bias=xops.LowerTriangularMask() if causal else None,
  206. op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
  207. )
  208. time_f[config, "xformers.c"] = f
  209. time_b[config, "xformers.c"] = b
  210. if xops is not None:
  211. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
  212. requires_grad=True) for _ in range(3)]
  213. f, b = time_fwd_bwd(
  214. xops.memory_efficient_attention, q, k, v,
  215. attn_bias=xops.LowerTriangularMask() if causal else None,
  216. op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
  217. )
  218. time_f[config, "xformers.f"] = f
  219. time_b[config, "xformers.f"] = b
  220. print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
  221. for method in methods:
  222. time_f_b[config, method] = time_f[config, method] + time_b[config, method]
  223. speed_f[config, method] = efficiency(
  224. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
  225. time_f[config, method]
  226. )
  227. speed_b[config, method] = efficiency(
  228. flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
  229. time_b[config, method]
  230. )
  231. speed_f_b[config, method] = efficiency(
  232. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
  233. time_f_b[config, method]
  234. )
  235. #print (time_f[config,method])
  236. print(
  237. f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
  238. f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
  239. f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
  240. )
  241. # with open('flash2_attn_time.plk', 'wb') as fp:
  242. # pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)