# bench.py
#
# Benchmark script for the flash_attn Triton AMD attention_prefill and
# attention_decode functions.
  1. import argparse
  2. import torch
  3. import triton
  4. from flash_attn.flash_attn_triton_amd.utils import (
  5. MetaData,
  6. input_helper,
  7. varlen_input_helper,
  8. )
  9. from flash_attn.flash_attn_triton_amd.interface_torch import attention_prefill, attention_decode
  10. ARGS_TO_TORCH_DTYPE = {
  11. "fp16": torch.float16,
  12. "bf16": torch.bfloat16,
  13. "fp32": torch.float32,
  14. }
  15. FUNCTIONS = {
  16. "prefill": attention_prefill,
  17. "decode": attention_decode
  18. }
  19. def get_benchmark_configs(args, varlen=False):
  20. """
  21. Returns benchmark configurations based on whether variable-length sequences are used.
  22. """
  23. if args.custom_config:
  24. hk = args.hq if not args.hk else args.hk
  25. sk = args.sq if not args.sk else args.sk
  26. return [(args.b, args.hq, hk, args.sq, sk)]
  27. elif varlen:
  28. return [
  29. (2, 16, 4, 1024, 1024),
  30. (8, 16, 2, 2048, 2048),
  31. (4, 16, 8, 4096, 4096),
  32. (2, 16, 4, 8192, 8192),
  33. (2, 16, 8, 16384, 16384),
  34. (2, 48, 12, 1024, 1024),
  35. (2, 48, 24, 2048, 2048),
  36. (2, 48, 8, 4096, 4096),
  37. (2, 48, 4, 8192, 8192),
  38. (2, 48, 2, 16384, 16384),
  39. (2, 64, 32, 1024, 1024),
  40. (4, 64, 16, 2048, 2048),
  41. (4, 64, 8, 4096, 4096),
  42. (4, 64, 32, 8192, 8192),
  43. (4, 128, 16, 16384, 16384),
  44. ]
  45. else:
  46. return [
  47. (16, 16, 16, 1024, 1024),
  48. (8, 16, 16, 2048, 2048),
  49. (4, 16, 16, 4096, 4096),
  50. (1, 8, 8, 8192, 8192),
  51. (1, 2, 2, 16384, 16384),
  52. (2, 48, 48, 1024, 1024),
  53. (2, 48, 48, 2048, 1024),
  54. (1, 8, 8, 4096, 8192),
  55. (1, 8, 8, 8192, 4096),
  56. (2, 4, 4, 16384, 8192),
  57. (2, 8, 8, 1989, 15344),
  58. (4, 16, 16, 4097, 163),
  59. (2, 16, 16, 8122, 2159),
  60. (1, 16, 16, 16281, 7),
  61. (2, 48, 48, 1021, 1020),
  62. (2, 48, 48, 2001, 2048),
  63. (2, 8, 8, 3996, 9639),
  64. (2, 8, 8, 8181, 1021),
  65. ]
  66. def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal):
  67. flops_per_matmul = 0
  68. if fn_name.startswith("prefill"):
  69. if layout == "thd":
  70. q, k, v, input_metadata = varlen_input_helper(
  71. BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device)
  72. for i in range(input_metadata.num_contexts):
  73. seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]
  74. seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]
  75. flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2
  76. else:
  77. q, k, v, input_metadata = input_helper(
  78. BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device
  79. )
  80. flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
  81. if causal:
  82. input_metadata.need_causal()
  83. o = torch.empty_like(q)
  84. input_data = (q, k, v, o, input_metadata)
  85. elif fn_name.startswith("decode"):
  86. q = torch.randn(
  87. [BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD],
  88. device=device,
  89. dtype=dtype,
  90. requires_grad=False,
  91. )
  92. k = torch.randn(
  93. [BATCH, N_CTX_K, HK, 1, D_HEAD],
  94. device=device,
  95. dtype=dtype,
  96. requires_grad=False,
  97. ).expand(-1, -1, -1, HQ // HK, -1)
  98. v = torch.randn(
  99. [BATCH, N_CTX_K, HK, 1, D_HEAD],
  100. device=device,
  101. dtype=dtype,
  102. requires_grad=False,
  103. ).expand(-1, -1, -1, HQ // HK, -1)
  104. input_metadata = MetaData(sm_scale=1.3)
  105. input_metadata.layout = "bsghd"
  106. # Adjust flops calculation if needed
  107. flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
  108. input_data = (q, k, v, input_metadata)
  109. else:
  110. raise ValueError("Unsupported benchmark function")
  111. return input_data, flops_per_matmul
  112. def run_benchmark(args, fn_name, fn, mode):
  113. """
  114. Runs the benchmark for the provided function based on the provided arguments.
  115. """
  116. print(f"Benchmarking {fn_name} in {mode} mode...")
  117. dtype = ARGS_TO_TORCH_DTYPE[args.dtype]
  118. head_size = args.d if args.d else 128
  119. causal = args.causal
  120. varlen = args.layout == "thd"
  121. return_tflops = args.return_tflops
  122. line_names = "TFLOPS" if return_tflops else "Time (ms)"
  123. # Determine configurations
  124. x_vals_list = get_benchmark_configs(args, varlen=varlen)
  125. # Setup benchmark configurations
  126. configs = [
  127. triton.testing.Benchmark(
  128. x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"],
  129. x_vals=x_vals_list,
  130. line_arg="provider",
  131. line_vals=["triton"],
  132. line_names=[line_names],
  133. styles=[("red", "-")],
  134. ylabel="ms",
  135. plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}",
  136. args={
  137. "D_HEAD": head_size,
  138. "dtype": dtype,
  139. "causal": causal,
  140. "mode": mode,
  141. },
  142. )
  143. ]
  144. @triton.testing.perf_report(configs)
  145. def bench_function(
  146. BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda"
  147. ):
  148. warmup = 25
  149. rep = 100
  150. flops_per_matmul = 0
  151. # generate function inputs
  152. fn_inputs, flops_per_matmul = gen_fn_inputs(
  153. fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal
  154. )
  155. # define the function to benchmark
  156. if mode == "fwd":
  157. benchmark_fn = lambda: fn(*fn_inputs)
  158. total_flops = 2 * flops_per_matmul
  159. elif mode == "bwd":
  160. outputs = fn(*fn_inputs)
  161. output = outputs[0]
  162. grad_output = torch.randn_like(output)
  163. benchmark_fn = lambda: output.backward(grad_output, retain_graph=True)
  164. total_flops = 2 * flops_per_matmul * 2.5
  165. else:
  166. raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.")
  167. if causal:
  168. total_flops *= 0.5
  169. # Run the benchmark
  170. ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep)
  171. if return_tflops:
  172. return total_flops / ms * 1e-9
  173. else:
  174. return ms
  175. bench_function.run(save_path=".", print_data=True)
  176. def supported_layouts():
  177. """
  178. Returns a string describing the supported layouts.
  179. """
  180. return (
  181. "bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n"
  182. "bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n"
  183. "thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n"
  184. 'This layout is sometimes called "varlen" or "grouped" layout.'
  185. )
  186. def parse_args():
  187. """
  188. Parses command-line arguments.
  189. """
  190. parser = argparse.ArgumentParser(
  191. prog="Benchmark FlashAttention",
  192. allow_abbrev=False,
  193. )
  194. parser.add_argument("-b", type=int, default=0)
  195. parser.add_argument("-hq", type=int, default=0)
  196. parser.add_argument("-hk", type=int, default=0)
  197. parser.add_argument("-sq", type=int, default=0)
  198. parser.add_argument("-sk", type=int, default=0)
  199. parser.add_argument(
  200. "-equal_seqlens",
  201. action="store_true",
  202. default=False,
  203. help="If specified, each context within the thd layout has same seqlen as sq and sk",
  204. )
  205. parser.add_argument("-d", type=int, default=0)
  206. parser.add_argument("-causal", action="store_true", default=False)
  207. parser.add_argument("-dtype", default="fp16")
  208. parser.add_argument("-return_tflops", action="store_true", default=False)
  209. parser.add_argument(
  210. "-layout",
  211. type=str,
  212. default="bhsd",
  213. help=supported_layouts(),
  214. )
  215. parser.add_argument(
  216. "-benchmark_fn",
  217. type=str,
  218. nargs="*",
  219. choices=FUNCTIONS.keys(),
  220. help="Function(s) to benchmark: prefill, decode, or both",
  221. )
  222. parser.add_argument(
  223. "-mode",
  224. type=str,
  225. nargs='*',
  226. default=["fwd", "bwd"],
  227. choices=["fwd", "bwd"],
  228. help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass",
  229. )
  230. return parser.parse_args()
  231. def main():
  232. """
  233. Main function to run benchmarks.
  234. """
  235. args = parse_args()
  236. # Validate arguments
  237. assert (
  238. args.layout == "thd" or not args.equal_seqlens
  239. ), "Equal sequence lengths arg must be used with the thd layout."
  240. args.custom_config = False
  241. if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
  242. args.custom_config = True
  243. assert args.b and args.hq and args.sq and args.d, (
  244. "If custom config is specified, please provide all of batch, "
  245. "number of Q heads, Q sequence length, and head size."
  246. )
  247. assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported."
  248. # determine the functions to benchmark
  249. if args.benchmark_fn is None or len(args.benchmark_fn) == 0:
  250. bench_fn_list = FUNCTIONS.keys()
  251. else:
  252. bench_fn_list = args.benchmark_fn
  253. # benchmark functions
  254. for fn_name in bench_fn_list:
  255. if fn_name not in FUNCTIONS:
  256. raise ValueError(f"Invalid benchmark function specified: {fn_name}")
  257. for mode in args.mode:
  258. if fn_name == "decode" and mode == "bwd":
  259. print(f"Decode kernel doesnot have a backward pass")
  260. continue
  261. run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)
  262. if __name__ == "__main__":
  263. main()