# benchmark_flash_attention.py (7.3 KB)
#
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
  3. import pickle
  4. import math
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from einops import rearrange, repeat
  9. from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
  10. from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
  11. from flash_attn import flash_attn_qkvpacked_func
  12. try:
  13. from triton.ops.flash_attention import attention as attention_triton
  14. except ImportError:
  15. attention_triton = None
  16. try:
  17. import xformers.ops as xops
  18. except ImportError:
  19. xops = None
  20. def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
  21. assert mode in ["fwd", "bwd", "fwd_bwd"]
  22. f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
  23. return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
  24. def efficiency(flop, time):
  25. return (flop / time / 10**12) if not math.isnan(time) else 0.0
  26. def attention_pytorch(qkv, dropout_p=0.0, causal=True):
  27. """
  28. Arguments:
  29. qkv: (batch_size, seqlen, 3, nheads, head_dim)
  30. dropout_p: float
  31. Output:
  32. output: (batch_size, seqlen, nheads, head_dim)
  33. """
  34. batch_size, seqlen, _, nheads, d = qkv.shape
  35. q, k, v = qkv.unbind(dim=2)
  36. q = rearrange(q, 'b t h d -> (b h) t d')
  37. k = rearrange(k, 'b s h d -> (b h) d s')
  38. softmax_scale = 1.0 / math.sqrt(d)
  39. # Preallocate attn_weights for `baddbmm`
  40. scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
  41. scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
  42. '(b h) t s -> b h t s', h=nheads)
  43. if causal:
  44. # "triu_tril_cuda_template" not implemented for 'BFloat16'
  45. # So we have to construct the mask in float
  46. causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
  47. # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
  48. scores = scores + causal_mask.to(dtype=scores.dtype)
  49. attention = torch.softmax(scores, dim=-1)
  50. attention_drop = F.dropout(attention, dropout_p)
  51. output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
  52. return output.to(dtype=qkv.dtype)
  53. def time_fwd_bwd(func, *args, **kwargs):
  54. time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
  55. return time_f[1].mean, time_b[1].mean
  56. repeats = 30
  57. device = 'cuda'
  58. dtype = torch.float16
  59. bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
  60. causal_vals = [False, True]
  61. headdim_vals = [64, 128]
  62. dim = 2048
  63. dropout_p = 0.0
  64. methods = (["Flash2", "Pytorch"]
  65. + (["Triton"] if attention_triton is not None else [])
  66. + (["xformers.c"] if xops is not None else [])
  67. + (["xformers.f"] if xops is not None else []))
  68. time_f = {}
  69. time_b = {}
  70. time_f_b = {}
  71. speed_f = {}
  72. speed_b = {}
  73. speed_f_b = {}
  74. for causal in causal_vals:
  75. for headdim in headdim_vals:
  76. for batch_size, seqlen in bs_seqlen_vals:
  77. config = (causal, headdim, batch_size, seqlen)
  78. nheads = dim // headdim
  79. qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
  80. requires_grad=True)
  81. f, b = time_fwd_bwd(
  82. flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
  83. )
  84. time_f[config, "Flash2"] = f
  85. time_b[config, "Flash2"] = b
  86. try:
  87. qkv = qkv.detach().requires_grad_(True)
  88. f, b = time_fwd_bwd(
  89. attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
  90. )
  91. except: # Skip if OOM
  92. f, b = float('nan'), float('nan')
  93. time_f[config, "Pytorch"] = f
  94. time_b[config, "Pytorch"] = b
  95. if attention_triton is not None:
  96. q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
  97. requires_grad=True) for _ in range(3)]
  98. # Try both values of sequence_parallel and pick the faster one
  99. try:
  100. f, b = time_fwd_bwd(
  101. attention_triton, q, k, v, causal, headdim**(-0.5),
  102. False, repeats=repeats, verbose=False
  103. )
  104. except:
  105. f, b = float('nan'), float('inf')
  106. try:
  107. _, b0 = time_fwd_bwd(
  108. attention_triton, q, k, v, causal, headdim**(-0.5),
  109. True, repeats=repeats, verbose=False
  110. )
  111. except:
  112. b0 = float('inf')
  113. time_f[config, "Triton"] = f
  114. time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
  115. if xops is not None:
  116. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
  117. requires_grad=True) for _ in range(3)]
  118. f, b = time_fwd_bwd(
  119. xops.memory_efficient_attention, q, k, v,
  120. attn_bias=xops.LowerTriangularMask() if causal else None,
  121. op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
  122. )
  123. time_f[config, "xformers.c"] = f
  124. time_b[config, "xformers.c"] = b
  125. if xops is not None:
  126. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
  127. requires_grad=True) for _ in range(3)]
  128. f, b = time_fwd_bwd(
  129. xops.memory_efficient_attention, q, k, v,
  130. attn_bias=xops.LowerTriangularMask() if causal else None,
  131. op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
  132. )
  133. time_f[config, "xformers.f"] = f
  134. time_b[config, "xformers.f"] = b
  135. print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
  136. for method in methods:
  137. time_f_b[config, method] = time_f[config, method] + time_b[config, method]
  138. speed_f[config, method] = efficiency(
  139. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
  140. time_f[config, method]
  141. )
  142. speed_b[config, method] = efficiency(
  143. flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
  144. time_b[config, method]
  145. )
  146. speed_f_b[config, method] = efficiency(
  147. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
  148. time_f_b[config, method]
  149. )
  150. print(
  151. f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
  152. f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
  153. f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
  154. )
  155. # with open('flash2_attn_time.plk', 'wb') as fp:
  156. # pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)