1
0

benchmark_alibi.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # Copyright (c) 2024, Sanghun Cho, Tri Dao.
  2. import pickle
  3. import math
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from einops import rearrange, repeat
  8. from flash_attn.layers.rotary import apply_rotary_emb
  9. from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
  10. from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
  11. from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
  12. try:
  13. import xformers.ops as xops
  14. except ImportError:
  15. xops = None
  16. def generate_cos_sin(seqlen, rotary_dim, device, dtype):
  17. assert rotary_dim % 2 == 0
  18. angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
  19. cos = torch.cos(angle).to(dtype=dtype)
  20. sin = torch.sin(angle).to(dtype=dtype)
  21. return cos, sin
  22. def flash_rotary(q, k, v, cos, sin, causal=False):
  23. # corrected by @tridao comments
  24. q = apply_rotary_emb(
  25. q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
  26. )
  27. k = apply_rotary_emb(
  28. k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
  29. )
  30. return flash_attn_func(q, k, v, causal=causal)
  31. def attn_bias_from_alibi_slopes(
  32. slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
  33. ):
  34. batch, nheads = slopes.shape
  35. device = slopes.device
  36. slopes = rearrange(slopes, "b h -> b h 1 1")
  37. if causal:
  38. return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
  39. else:
  40. row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
  41. col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
  42. sk = (
  43. seqlen_k
  44. if key_padding_mask is None
  45. else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
  46. )
  47. sq = (
  48. seqlen_q
  49. if query_padding_mask is None
  50. else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
  51. )
  52. relative_pos = torch.abs(row_idx + sk - sq - col_idx)
  53. return -slopes * relative_pos.to(dtype=slopes.dtype)
  54. def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
  55. assert mode in ["fwd", "bwd", "fwd_bwd"]
  56. f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
  57. return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
  58. def efficiency(flop, time):
  59. return (flop / time / 10**12) if not math.isnan(time) else 0.0
  60. def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
  61. """
  62. Arguments:
  63. q, k, v: (batch_size, seqlen, nheads, head_dim)
  64. dropout_p: float
  65. attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
  66. Output:
  67. output: (batch_size, seqlen, nheads, head_dim)
  68. """
  69. batch_size, seqlen, nheads, d = q.shape
  70. q = rearrange(q, 'b t h d -> (b h) t d')
  71. k = rearrange(k, 'b s h d -> (b h) d s')
  72. softmax_scale = 1.0 / math.sqrt(d)
  73. # Preallocate attn_weights for `baddbmm`
  74. if attn_bias is not None:
  75. scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
  76. else:
  77. scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
  78. scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),
  79. '(b h) t s -> b h t s', h=nheads)
  80. if causal:
  81. # "triu_tril_cuda_template" not implemented for 'BFloat16'
  82. # So we have to construct the mask in float
  83. causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
  84. # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
  85. scores = scores + causal_mask.to(dtype=scores.dtype)
  86. attention = torch.softmax(scores, dim=-1)
  87. attention_drop = F.dropout(attention, dropout_p)
  88. output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
  89. return output.to(dtype=q.dtype)
  90. def time_fwd_bwd(func, *args, **kwargs):
  91. time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
  92. return time_f[1].mean, time_b[1].mean
  93. repeats = 30
  94. device = 'cuda'
  95. dtype = torch.float16
  96. bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
  97. causal_vals = [False, True]
  98. headdim_vals = [64, 128]
  99. dim = 2048
  100. dropout_p = 0.0
  101. methods = (["fa2_alibi", "torch"]
  102. + (["xformers"] if xops is not None else [])
  103. + ["sdpa"]
  104. + ["fa2_baseline"]
  105. + ["fa2_rotary"])
  106. time_f = {}
  107. time_b = {}
  108. time_f_b = {}
  109. speed_f = {}
  110. speed_b = {}
  111. speed_f_b = {}
  112. for causal in causal_vals:
  113. for headdim in headdim_vals:
  114. for batch_size, seqlen in bs_seqlen_vals:
  115. config = (causal, headdim, batch_size, seqlen)
  116. nheads = dim // headdim
  117. q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
  118. requires_grad=True) for _ in range(3)]
  119. # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
  120. alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
  121. attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
  122. attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
  123. f, b = time_fwd_bwd(
  124. flash_attn_func,
  125. q, k, v,
  126. dropout_p,
  127. causal=causal,
  128. # alibi_slopes=alibi_slopes,
  129. alibi_slopes=None,
  130. repeats=repeats,
  131. verbose=False
  132. )
  133. time_f[config, "fa2_baseline"] = f
  134. time_b[config, "fa2_baseline"] = b
  135. q = q.detach().requires_grad_(True)
  136. k = k.detach().requires_grad_(True)
  137. v = v.detach().requires_grad_(True)
  138. f, b = time_fwd_bwd(
  139. flash_attn_func,
  140. q, k, v,
  141. dropout_p,
  142. causal=causal,
  143. alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
  144. # alibi_slopes=None,
  145. repeats=repeats,
  146. verbose=False
  147. )
  148. time_f[config, "fa2_alibi"] = f
  149. time_b[config, "fa2_alibi"] = b
  150. try:
  151. q = q.detach().requires_grad_(True)
  152. k = k.detach().requires_grad_(True)
  153. v = v.detach().requires_grad_(True)
  154. f, b = time_fwd_bwd(
  155. attention_pytorch,
  156. q, k, v,
  157. dropout_p,
  158. causal=causal,
  159. attn_bias=attn_bias,
  160. repeats=repeats,
  161. verbose=False
  162. )
  163. except: # Skip if OOM
  164. f, b = float('nan'), float('nan')
  165. time_f[config, "torch"] = f
  166. time_b[config, "torch"] = b
  167. # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
  168. with torch.backends.cuda.sdp_kernel(enable_flash=False):
  169. q_pt = q.detach().requires_grad_(True).transpose(1, 2)
  170. k_pt = k.detach().requires_grad_(True).transpose(1, 2)
  171. v_pt = v.detach().requires_grad_(True).transpose(1, 2)
  172. f, b = time_fwd_bwd(
  173. F.scaled_dot_product_attention,
  174. q_pt, k_pt, v_pt,
  175. attn_mask=attn_bias,
  176. dropout_p=dropout_p,
  177. is_causal=causal,
  178. repeats=repeats,
  179. verbose=False
  180. )
  181. time_f[config, "sdpa"] = f
  182. time_b[config, "sdpa"] = b
  183. if xops is not None:
  184. q = q.detach().requires_grad_(True)
  185. k = k.detach().requires_grad_(True)
  186. v = v.detach().requires_grad_(True)
  187. if causal:
  188. attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
  189. # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
  190. # `flshattB@v2.3.6` is not supported because:
  191. # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
  192. # `cutlassB` is not supported because:
  193. # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
  194. attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
  195. else:
  196. attn_bias_xops = attn_bias.to(dtype=q.dtype)
  197. f, b = time_fwd_bwd(
  198. xops.memory_efficient_attention,
  199. q, k, v,
  200. attn_bias_xops,
  201. dropout_p,
  202. repeats=repeats,
  203. verbose=False
  204. )
  205. time_f[config, "xformers"] = f
  206. time_b[config, "xformers"] = b
  207. q = q.detach().requires_grad_(True)
  208. k = k.detach().requires_grad_(True)
  209. v = v.detach().requires_grad_(True)
  210. cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
  211. f, b = time_fwd_bwd(
  212. flash_rotary,
  213. q, k, v,
  214. cos, sin,
  215. causal,
  216. repeats=repeats,
  217. verbose=False
  218. )
  219. time_f[config, "fa2_rotary"] = f
  220. time_b[config, "fa2_rotary"] = b
  221. print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
  222. csv_output = ""
  223. csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
  224. for method in methods:
  225. time_f_b[config, method] = time_f[config, method] + time_b[config, method]
  226. speed_f[config, method] = efficiency(
  227. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
  228. time_f[config, method]
  229. )
  230. speed_b[config, method] = efficiency(
  231. flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
  232. time_b[config, method]
  233. )
  234. speed_f_b[config, method] = efficiency(
  235. flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
  236. time_f_b[config, method]
  237. )
  238. print(
  239. f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
  240. f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
  241. f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
  242. )
  243. csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
  244. print(csv_output)