benchmark_flash_attention_fp8.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # Install the newest triton version with
  2. # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
  3. import pickle
  4. import math
  5. import time
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from einops import rearrange, repeat
  10. from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
  11. from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
  12. from flash_attn import flash_attn_qkvpacked_func
  13. from flash_attn_interface import flash_attn_func, _flash_attn_forward
  14. try:
  15. from triton_fused_attention import attention as attention_triton
  16. except ImportError:
  17. attention_triton = None
  18. try:
  19. import xformers.ops as xops
  20. except ImportError:
  21. xops = None
  22. try:
  23. import cudnn
  24. except ImportError:
  25. cudnn = None
  26. def convert_to_cudnn_type(torch_type):
  27. if torch_type == torch.float16:
  28. return cudnn.data_type.HALF
  29. elif torch_type == torch.bfloat16:
  30. return cudnn.data_type.BFLOAT16
  31. elif torch_type == torch.float32:
  32. return cudnn.data_type.FLOAT
  33. elif torch_type == torch.int32:
  34. return cudnn.data_type.INT32
  35. elif torch_type == torch.int64:
  36. return cudnn.data_type.INT64
  37. elif torch_type == torch.float8_e4m3fn:
  38. return cudnn.data_type.FP8_E4M3
  39. elif torch_type == torch.float8_e4m3fn:
  40. return cudnn.data_type.FP8_E5M2
  41. else:
  42. raise ValueError("Unsupported tensor data type.")
  43. def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
  44. b, _, _, nheads, headdim = qkv.shape
  45. assert cudnn is not None, 'CUDNN is not available'
  46. o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
  47. o_gpu_transposed = torch.as_strided(
  48. o_gpu,
  49. [b, nheads, seqlen_q, headdim],
  50. [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
  51. )
  52. stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
  53. amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
  54. amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
  55. graph = cudnn.pygraph(
  56. io_data_type=convert_to_cudnn_type(qkv.dtype),
  57. intermediate_data_type=cudnn.data_type.FLOAT,
  58. compute_data_type=cudnn.data_type.FLOAT,
  59. )
  60. new_q = torch.as_strided(
  61. qkv,
  62. [b, nheads, seqlen_q, headdim],
  63. [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
  64. storage_offset=0,
  65. )
  66. q = graph.tensor(
  67. name = "Q",
  68. dim = list(new_q.shape),
  69. stride = list(new_q.stride()),
  70. data_type=convert_to_cudnn_type(qkv.dtype)
  71. )
  72. new_k = torch.as_strided(
  73. qkv,
  74. [b, nheads, seqlen_k, headdim],
  75. [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
  76. storage_offset=nheads * headdim,
  77. )
  78. k = graph.tensor(
  79. name = "K",
  80. dim = list(new_k.shape),
  81. stride = list(new_k.stride()),
  82. data_type=convert_to_cudnn_type(qkv.dtype)
  83. )
  84. new_v = torch.as_strided(
  85. qkv,
  86. [b, nheads, seqlen_k, headdim],
  87. [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
  88. storage_offset=nheads * headdim * 2,
  89. )
  90. v = graph.tensor(
  91. name = "V",
  92. dim = list(new_v.shape),
  93. stride = list(new_v.stride()),
  94. data_type=convert_to_cudnn_type(qkv.dtype)
  95. )
  96. def get_default_scale_tensor():
  97. return graph.tensor(
  98. dim = [1, 1, 1, 1],
  99. stride = [1, 1, 1, 1],
  100. data_type=cudnn.data_type.FLOAT
  101. )
  102. default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
  103. descale_q = get_default_scale_tensor()
  104. descale_k = get_default_scale_tensor()
  105. descale_v = get_default_scale_tensor()
  106. descale_s = get_default_scale_tensor()
  107. scale_s = get_default_scale_tensor()
  108. scale_o = get_default_scale_tensor()
  109. o, _, amax_s, amax_o = graph.sdpa_fp8(
  110. q=q,
  111. k=k,
  112. v=v,
  113. descale_q=descale_q,
  114. descale_k=descale_k,
  115. descale_v=descale_v,
  116. descale_s=descale_s,
  117. scale_s=scale_s,
  118. scale_o=scale_o,
  119. is_inference=True,
  120. attn_scale=1.0 / math.sqrt(headdim),
  121. use_causal_mask=causal,
  122. name="sdpa",
  123. )
  124. o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
  125. amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
  126. amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
  127. # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
  128. graph.validate()
  129. graph.build_operation_graph()
  130. graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
  131. graph.check_support()
  132. graph.build_plans()
  133. variant_pack = {
  134. q: new_q,
  135. k: new_k,
  136. v: new_v,
  137. descale_q: default_scale_gpu,
  138. descale_k: default_scale_gpu,
  139. descale_v: default_scale_gpu,
  140. descale_s: default_scale_gpu,
  141. scale_s: default_scale_gpu,
  142. scale_o: default_scale_gpu,
  143. o: o_gpu_transposed,
  144. amax_s: amax_s_gpu,
  145. amax_o: amax_o_gpu,
  146. }
  147. workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
  148. def run(*args, **kwargs):
  149. graph.execute(variant_pack, workspace)
  150. return o_gpu, amax_o_gpu
  151. return run
  152. def attention_pytorch(qkv, dropout_p=0.0, causal=True):
  153. """
  154. Arguments:
  155. qkv: (batch_size, seqlen, 3, nheads, head_dim)
  156. dropout_p: float
  157. Output:
  158. output: (batch_size, seqlen, nheads, head_dim)
  159. """
  160. batch_size, seqlen, _, nheads, d = qkv.shape
  161. q, k, v = qkv.unbind(dim=2)
  162. q = rearrange(q, 'b t h d -> (b h) t d')
  163. k = rearrange(k, 'b s h d -> (b h) d s')
  164. softmax_scale = 1.0 / math.sqrt(d)
  165. # Preallocate attn_weights for `baddbmm`
  166. scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
  167. scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
  168. '(b h) t s -> b h t s', h=nheads)
  169. if causal:
  170. # "triu_tril_cuda_template" not implemented for 'BFloat16'
  171. # So we have to construct the mask in float
  172. causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
  173. # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
  174. scores = scores + causal_mask.to(dtype=scores.dtype)
  175. attention = torch.softmax(scores, dim=-1)
  176. attention_drop = F.dropout(attention, dropout_p)
  177. output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
  178. return output.to(dtype=qkv.dtype)
  179. def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
  180. assert mode in ["fwd", "bwd", "fwd_bwd"]
  181. f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
  182. return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
  183. def efficiency(flop, time):
  184. return (flop / time / 10**12) if not math.isnan(time) else 0.0
  185. def time_fwd(func, *args, **kwargs):
  186. time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
  187. time_f = benchmark_forward(func, *args, **kwargs)
  188. return time_f[1].mean
  189. torch.manual_seed(0)
  190. repeats = 30
  191. device = 'cuda'
  192. # dtype = torch.float16
  193. dtype = torch.float8_e4m3fn
  194. # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
  195. bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
  196. # bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2)]
  197. # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
  198. causal_vals = [False, True]
  199. headdim_vals = [64, 128, 256]
  200. dim = 2048
  201. # dim = 256
  202. dropout_p = 0.0
  203. methods = (["Pytorch", "Flash3"]
  204. + (["cuDNN"] if cudnn is not None else [])
  205. # + (["Triton"] if attention_triton is not None else [])
  206. # + (["xformers.c"] if xops is not None else [])
  207. # + (["xformers.f"] if xops is not None else [])
  208. )
  209. time_f = {}
  210. time_b = {}
  211. time_f_b = {}
  212. speed_f = {}
  213. speed_b = {}
  214. speed_f_b = {}
  215. for causal in causal_vals:
  216. for headdim in headdim_vals:
  217. for batch_size, seqlen in bs_seqlen_vals:
  218. torch.cuda.empty_cache()
  219. config = (causal, headdim, batch_size, seqlen)
  220. nheads = dim // headdim
  221. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)]
  222. qkv = torch.stack([q, k, v], dim=2)
  223. qkv = qkv.to(torch.bfloat16)
  224. f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
  225. time_f[config, "Pytorch"] = f
  226. res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
  227. if attention_triton is not None:
  228. q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
  229. k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
  230. v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
  231. scale = 1 / math.sqrt(headdim)
  232. f = time_fwd(
  233. attention_triton, q_transposed, k_transposed, v_transposed,
  234. causal, scale, repeats=5, verbose=False, desc='Triton'
  235. )
  236. f = time_fwd(
  237. attention_triton, q_transposed, k_transposed, v_transposed,
  238. causal, scale, repeats=repeats, verbose=False, desc='Triton'
  239. )
  240. time_f[config, "Triton"] = f
  241. res = attention_triton(
  242. q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
  243. causal, scale
  244. ).half().transpose(1, 2)
  245. torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
  246. # out = torch.empty_like(q)
  247. q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
  248. softmax_scale = q.shape[-1] ** (-0.5)
  249. descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')
  250. descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')
  251. descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')
  252. # f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
  253. f = time_fwd(
  254. _flash_attn_forward,
  255. q,
  256. k,
  257. v,
  258. softmax_scale,
  259. causal=causal,
  260. window_size=(-1,-1),
  261. descale_q=descale_q,
  262. descale_k=descale_k,
  263. descale_v=descale_v,
  264. repeats=repeats,
  265. verbose=False
  266. )
  267. # res = flash_attn_func(q, k, v, causal=causal)
  268. # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)
  269. time_f[config, "Flash3"] = f
  270. if cudnn is not None:
  271. qkv_fp8 = qkv.to(dtype)
  272. time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
  273. f = time_fwd(
  274. cudnn_spda_setup(
  275. qkv_fp8, seqlen, seqlen,
  276. causal=causal
  277. ),
  278. repeats=repeats, verbose=False
  279. )
  280. time_f[config, "cuDNN"] = f
  281. # res, amax_o = cudnn_spda_setup(
  282. # qkv_fp8, seqlen, seqlen,
  283. # causal=causal
  284. # )()
  285. # res = res.half()
  286. # TODO: CUDNN has numerics issues when
  287. # num_heads=16, dim=128, seq_len=1024, batch_size=2
  288. # or larger sizes.
  289. # res_cpu = res.cpu().reshape(-1)
  290. # res_baseline_cpu = res_baseline.cpu().reshape(-1)
  291. # print(amax_o)
  292. # print(res)
  293. # print(res_baseline)
  294. # for i in range(len(res_cpu)):
  295. # item = res_cpu[i]
  296. # item_baseline = res_baseline_cpu[i]
  297. # if abs(item - item_baseline) > 0.5:
  298. # print(i)
  299. # print(item)
  300. # print(item_baseline)
  301. # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)
  302. print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
  303. for method in methods:
  304. speed_f[config, method] = efficiency(
  305. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
  306. time_f[config, method]
  307. )
  308. #print (time_f[config,method])
  309. print(
  310. f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
  311. )
  312. # with open('flash3_attn_time.plk', 'wb') as fp:
  313. # pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)