# benchmark_split_kv.py
# Benchmarks split-KV decoding with flash_attn_with_kvcache (FlashAttention-2 vs FlashAttention-3).
import functools
import itertools
import math
import time

import torch
import torch.utils.benchmark as benchmark

import flash_attn
import flash_attn_interface
  8. def round_up_to_power_of_2(x):
  9. if x <= 1:
  10. return 1
  11. return 1 << (x - 1).bit_length()
  12. def timeit(fn, *args, **kwargs):
  13. torch.cuda.synchronize()
  14. # Warmup
  15. for _ in range(5):
  16. fn(*args, **kwargs)
  17. # Benchmark using PyTorch Timer
  18. t = benchmark.Timer(
  19. stmt='fn(*args, **kwargs)',
  20. globals={'fn': fn, 'args': args, 'kwargs': kwargs}
  21. )
  22. # Measure execution time
  23. measurement = t.timeit(20) # Runs the function 20 times
  24. # measurement = t.blocked_autorange(min_run_time=1)
  25. avg_time = measurement.mean # Average time in seconds
  26. return avg_time
  27. def main():
  28. num_sms = torch.cuda.get_device_properties(
  29. torch.cuda.current_device()
  30. ).multi_processor_count
  31. max_splits = 129
  32. check_all_splits = True
  33. causal = True
  34. # causal = False
  35. # dtype=torch.float16
  36. dtype=torch.bfloat16
  37. tp_degree = 1
  38. torch.manual_seed(42)
  39. model_configs = [
  40. # ("Gemma-2-2B", 8, 4, 256),
  41. # ("Gemma-2-9B", 16, 8, 256),
  42. # ("Gemma-2-27B", 32, 16, 128),
  43. # ("Qwen-2.5-0.5B", 14, 2, 64),
  44. # ("Qwen-2.5-1.5B", 12, 2, 128),
  45. # ("Qwen-2.5-7B", 28, 4, 128),
  46. # ("Llama-3.1-8B", 32, 8, 128),
  47. ("Llama-3.1-70B", 64, 8, 128),
  48. # ("Mistral Large", 96, 8, 128),
  49. # ("Llama-3.1-405B", 128, 8, 128),
  50. # ("Llama-3.2-1B", 32, 8, 64),
  51. # ("Llama-3.2-3B", 24, 8, 128),
  52. # ("Nemotron-4-15B", 48, 8, 128),
  53. ]
  54. all_batch_configs = []
  55. all_batch_configs.extend(itertools.product(
  56. # [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen
  57. # [4096, 16384, 65536], # context_seqlen
  58. [131072], # context_seqlen
  59. # [i for i in range(1, (num_sms) + 1)], # num_requests
  60. [1, 4, 8, 16], # num_requests
  61. # [1], # num_requests
  62. # [1, 4, 8, 16], # query_seqlen
  63. [1], # query_seqlen
  64. ))
  65. num_caches = max(reqs for _, reqs, _ in all_batch_configs)
  66. cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)
  67. for model_name, nheads_q, nheads_kv, headdim in model_configs:
  68. assert nheads_kv % tp_degree == 0
  69. print(f"***{model_name}***")
  70. print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}, TP:{tp_degree}")
  71. nheads_q //= tp_degree
  72. nheads_kv //= tp_degree
  73. k_cache = torch.randn(
  74. (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
  75. )
  76. v_cache = torch.randn(
  77. (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
  78. )
  79. if check_all_splits is False:
  80. print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")
  81. for context_seqlen, num_requests, query_seqlen in all_batch_configs:
  82. bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)
  83. bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)
  84. blockH = round_up_to_power_of_2(nheads_q//nheads_kv)
  85. blockM = 128 # true for hdim 128 causal and hdim 64
  86. blockM_div_H = blockM//blockH
  87. num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H)
  88. q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype)
  89. cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
  90. cache_seqlens = torch.tensor(
  91. [context_seqlen] * num_requests, dtype=torch.int32, device="cuda"
  92. )
  93. fa2_time_heuristic = timeit(
  94. flash_attn.flash_attn_with_kvcache,
  95. q=q,
  96. k_cache=k_cache,
  97. v_cache=v_cache,
  98. cache_seqlens=cache_seqlens,
  99. cache_batch_idx=cache_idxs,
  100. causal=causal,
  101. ) * 1000. * 1000.
  102. # fastest_splitk_time = float("inf")
  103. # fastest_splitk = 0
  104. # for i in range(1, max_splits):
  105. # t = timeit(
  106. # flash_attn.flash_attn_with_kvcache,
  107. # q=q,
  108. # k_cache=k_cache,
  109. # v_cache=v_cache,
  110. # cache_seqlens=cache_seqlens,
  111. # cache_batch_idx=cache_idxs,
  112. # causal=causal,
  113. # num_splits=i,
  114. # ) * 1000. * 1000.
  115. # if t < fastest_splitk_time:
  116. # fastest_splitk_time = t
  117. # fastest_splitk = i
  118. fa3_time_one_split = timeit(
  119. flash_attn_interface.flash_attn_with_kvcache,
  120. q=q,
  121. k_cache=k_cache,
  122. v_cache=v_cache,
  123. cache_seqlens=cache_seqlens,
  124. cache_batch_idx=cache_idxs,
  125. causal=causal,
  126. pack_gqa=False,
  127. num_splits=1,
  128. ) * 1000. * 1000.
  129. fa3_time_gqa_heuristic = timeit(
  130. flash_attn_interface.flash_attn_with_kvcache,
  131. q=q,
  132. k_cache=k_cache,
  133. v_cache=v_cache,
  134. cache_seqlens=cache_seqlens,
  135. cache_batch_idx=cache_idxs,
  136. causal=causal,
  137. pack_gqa=True,
  138. num_splits=0,
  139. # max_seqlen_k_hint=context_seqlen
  140. ) * 1000. * 1000.
  141. if check_all_splits:
  142. fa3_fastest_num_splits = 0
  143. fa3_fastest_splitk_time = float("inf")
  144. for num_splits in range(1, max_splits):
  145. t = timeit(
  146. flash_attn_interface.flash_attn_with_kvcache,
  147. q=q,
  148. k_cache=k_cache,
  149. v_cache=v_cache,
  150. cache_seqlens=cache_seqlens,
  151. cache_batch_idx=cache_idxs,
  152. causal=causal,
  153. pack_gqa=False,
  154. num_splits=num_splits
  155. ) * 1000. * 1000.
  156. out0 = flash_attn_interface.flash_attn_with_kvcache(
  157. q=q,
  158. k_cache=k_cache,
  159. v_cache=v_cache,
  160. cache_seqlens=cache_seqlens,
  161. cache_batch_idx=cache_idxs,
  162. causal=causal,
  163. pack_gqa=False,
  164. num_splits=num_splits
  165. )
  166. out1 = flash_attn_interface.flash_attn_with_kvcache(
  167. q=q,
  168. k_cache=k_cache,
  169. v_cache=v_cache,
  170. cache_seqlens=cache_seqlens,
  171. cache_batch_idx=cache_idxs,
  172. causal=causal,
  173. pack_gqa=False,
  174. num_splits=1
  175. )
  176. max_diff = (out0 - out1).abs().max().item()
  177. mean_diff = (out0 - out1).abs().mean().item()
  178. # print (f"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}")
  179. # print (f"splits {num_splits}, time {t:.2f}")
  180. if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
  181. print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")
  182. if t < fa3_fastest_splitk_time:
  183. fa3_fastest_splitk_time = t
  184. fa3_fastest_num_splits = num_splits
  185. fa3_fastest_num_splits_gqa = 0
  186. fa3_fastest_splitk_time_gqa = float("inf")
  187. for num_splits in range(1, max_splits):
  188. t = timeit(
  189. flash_attn_interface.flash_attn_with_kvcache,
  190. q=q,
  191. k_cache=k_cache,
  192. v_cache=v_cache,
  193. cache_seqlens=cache_seqlens,
  194. cache_batch_idx=cache_idxs,
  195. causal=causal,
  196. pack_gqa=True,
  197. num_splits=num_splits
  198. ) * 1000. * 1000.
  199. out0 = flash_attn_interface.flash_attn_with_kvcache(
  200. q=q,
  201. k_cache=k_cache,
  202. v_cache=v_cache,
  203. cache_seqlens=cache_seqlens,
  204. cache_batch_idx=cache_idxs,
  205. causal=causal,
  206. pack_gqa=True,
  207. num_splits=num_splits
  208. )
  209. out1 = flash_attn_interface.flash_attn_with_kvcache(
  210. q=q,
  211. k_cache=k_cache,
  212. v_cache=v_cache,
  213. cache_seqlens=cache_seqlens,
  214. cache_batch_idx=cache_idxs,
  215. causal=causal,
  216. pack_gqa=True,
  217. num_splits=1
  218. )
  219. max_diff = (out0 - out1).abs().max().item()
  220. mean_diff = (out0 - out1).abs().mean().item()
  221. # print (f"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}")
  222. # print (f"gqa splits {num_splits}, time {t:.2f}")
  223. if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
  224. print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")
  225. if t < fa3_fastest_splitk_time_gqa:
  226. fa3_fastest_splitk_time_gqa = t
  227. fa3_fastest_num_splits_gqa = num_splits
  228. efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms
  229. heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa
  230. # remeasure to smooth anomalies
  231. if heuristic_ratio > 1.1:
  232. fa3_time_gqa_heuristic = timeit(
  233. flash_attn_interface.flash_attn_with_kvcache,
  234. q=q,
  235. k_cache=k_cache,
  236. v_cache=v_cache,
  237. cache_seqlens=cache_seqlens,
  238. cache_batch_idx=cache_idxs,
  239. causal=causal,
  240. pack_gqa=True,
  241. # num_splits=num_splits_select,
  242. # num_splits=1,
  243. num_splits=0,
  244. # max_seqlen_k_hint=context_seqlen
  245. ) * 1000. * 1000.
  246. fa3_fastest_splitk_time_gqa = timeit(
  247. flash_attn_interface.flash_attn_with_kvcache,
  248. q=q,
  249. k_cache=k_cache,
  250. v_cache=v_cache,
  251. cache_seqlens=cache_seqlens,
  252. cache_batch_idx=cache_idxs,
  253. causal=causal,
  254. pack_gqa=True,
  255. num_splits=fa3_fastest_num_splits_gqa
  256. ) * 1000. * 1000.
  257. if check_all_splits is True:
  258. print(
  259. f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, "
  260. f"FA2:{fa2_time_heuristic:.2f}, "
  261. # f"FA2 MANUAL:{fastest_splitk_time:.2f}, "
  262. # f"FA2 NUM SPLITS:{fastest_splitk}, "
  263. # f"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, "
  264. # f"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, "
  265. # f"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, "
  266. f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, "
  267. f"FA3:{fa3_time_gqa_heuristic:.2f}, "
  268. # f"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, "
  269. # f"FA2 NUM SPLITS:{fastest_splitk}, "
  270. # f"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, "
  271. f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, "
  272. # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
  273. f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
  274. f"EFF:{efficiency:.2f}, "
  275. f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
  276. )
  277. if check_all_splits is False:
  278. print(
  279. f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}"
  280. f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}"
  281. f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}"
  282. f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}"
  283. )
  284. if __name__ == "__main__":
  285. main()