# benchmark_split_kv.py

  1. import torch
  2. import flash_attn
  3. import flash_attn_interface
  4. import itertools
  5. import time
  6. import math
  7. import torch.utils.benchmark as benchmark
  8. def round_up_to_power_of_2(x):
  9. if x <= 1:
  10. return 1
  11. return 1 << (x - 1).bit_length()
  12. def timeit(fn, *args, **kwargs):
  13. torch.cuda.synchronize()
  14. # Warmup
  15. for _ in range(5):
  16. fn(*args, **kwargs)
  17. # Benchmark using PyTorch Timer
  18. t = benchmark.Timer(
  19. stmt='fn(*args, **kwargs)',
  20. globals={'fn': fn, 'args': args, 'kwargs': kwargs}
  21. )
  22. # Measure execution time
  23. measurement = t.timeit(20) # Runs the function 20 times
  24. # measurement = t.blocked_autorange(min_run_time=1)
  25. avg_time = measurement.mean # Average time in seconds
  26. return avg_time
  27. def main():
  28. num_sms = torch.cuda.get_device_properties(
  29. torch.cuda.current_device()
  30. ).multi_processor_count
  31. max_splits = 129
  32. check_all_splits = False
  33. causal = True
  34. # causal = False
  35. # dtype=torch.float16
  36. dtype=torch.bfloat16
  37. torch.manual_seed(42)
  38. model_configs = [
  39. # ("Gemma-2-2B", 8, 4, 256),
  40. # ("Gemma-2-9B", 16, 8, 256),
  41. # ("Gemma-2-27B", 32, 16, 128),
  42. # ("Qwen-2.5-0.5B", 14, 2, 64),
  43. # ("Qwen-2.5-1.5B", 12, 2, 128),
  44. # ("Qwen-2.5-7B", 28, 4, 128),
  45. # ("Llama-3.1-8B", 32, 8, 128),
  46. ("Llama-3.1-70B", 64, 8, 128),
  47. # ("Llama-3.1-405B", 128, 8, 128),
  48. # ("Llama-3.2-1B", 32, 8, 64),
  49. # ("Llama-3.2-3B", 24, 8, 128),
  50. # ("Nemotron-4-15B", 48, 8, 128),
  51. ]
  52. all_batch_configs = []
  53. all_batch_configs.extend(itertools.product(
  54. # [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen
  55. [4096, 16384, 65536], # context_seqlen
  56. # [131072], # context_seqlen
  57. # [i for i in range(1, (num_sms) + 1)], # num_requests
  58. [1, 4, 8, 16], # num_requests
  59. # [1], # num_requests
  60. [1, 4, 8, 16], # query_seqlen
  61. # [1], # query_seqlen
  62. ))
  63. num_caches = max(reqs for _, reqs, _ in all_batch_configs)
  64. cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)
  65. for model_name, nheads_q, nheads_kv, headdim in model_configs:
  66. k_cache = torch.randn(
  67. (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
  68. )
  69. v_cache = torch.randn(
  70. (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
  71. )
  72. print(f"***{model_name}***")
  73. print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}")
  74. if check_all_splits is False:
  75. print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")
  76. for context_seqlen, num_requests, query_seqlen in all_batch_configs:
  77. bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)
  78. bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)
  79. blockH = round_up_to_power_of_2(nheads_q//nheads_kv)
  80. blockM = 128 # true for hdim 128 causal and hdim 64
  81. blockM_div_H = blockM//blockH
  82. num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H)
  83. q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype)
  84. cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
  85. cache_seqlens = torch.tensor(
  86. [context_seqlen] * num_requests, dtype=torch.int32, device="cuda"
  87. )
  88. fa2_time_heuristic = timeit(
  89. flash_attn.flash_attn_with_kvcache,
  90. q=q,
  91. k_cache=k_cache,
  92. v_cache=v_cache,
  93. cache_seqlens=cache_seqlens,
  94. cache_batch_idx=cache_idxs,
  95. causal=causal,
  96. ) * 1000. * 1000.
  97. # fastest_splitk_time = float("inf")
  98. # fastest_splitk = 0
  99. # for i in range(1, max_splits):
  100. # t = timeit(
  101. # flash_attn.flash_attn_with_kvcache,
  102. # q=q,
  103. # k_cache=k_cache,
  104. # v_cache=v_cache,
  105. # cache_seqlens=cache_seqlens,
  106. # cache_batch_idx=cache_idxs,
  107. # causal=causal,
  108. # num_splits=i,
  109. # ) * 1000. * 1000.
  110. # if t < fastest_splitk_time:
  111. # fastest_splitk_time = t
  112. # fastest_splitk = i
  113. fa3_time_one_split = timeit(
  114. flash_attn_interface.flash_attn_with_kvcache,
  115. q=q,
  116. k_cache=k_cache,
  117. v_cache=v_cache,
  118. cache_seqlens=cache_seqlens,
  119. cache_batch_idx=cache_idxs,
  120. causal=causal,
  121. gqa_parallel=False,
  122. num_splits=1,
  123. ) * 1000. * 1000.
  124. fa3_time_gqa_heuristic = timeit(
  125. flash_attn_interface.flash_attn_with_kvcache,
  126. q=q,
  127. k_cache=k_cache,
  128. v_cache=v_cache,
  129. cache_seqlens=cache_seqlens,
  130. cache_batch_idx=cache_idxs,
  131. causal=causal,
  132. gqa_parallel=True,
  133. num_splits=0,
  134. max_seqlen_k_hint=context_seqlen
  135. ) * 1000. * 1000.
  136. if check_all_splits:
  137. fa3_fastest_num_splits = 0
  138. fa3_fastest_splitk_time = float("inf")
  139. for num_splits in range(1, max_splits):
  140. t = timeit(
  141. flash_attn_interface.flash_attn_with_kvcache,
  142. q=q,
  143. k_cache=k_cache,
  144. v_cache=v_cache,
  145. cache_seqlens=cache_seqlens,
  146. cache_batch_idx=cache_idxs,
  147. causal=causal,
  148. gqa_parallel=False,
  149. num_splits=num_splits
  150. ) * 1000. * 1000.
  151. out0 = flash_attn_interface.flash_attn_with_kvcache(
  152. q=q,
  153. k_cache=k_cache,
  154. v_cache=v_cache,
  155. cache_seqlens=cache_seqlens,
  156. cache_batch_idx=cache_idxs,
  157. causal=causal,
  158. gqa_parallel=False,
  159. num_splits=num_splits
  160. )
  161. out1 = flash_attn_interface.flash_attn_with_kvcache(
  162. q=q,
  163. k_cache=k_cache,
  164. v_cache=v_cache,
  165. cache_seqlens=cache_seqlens,
  166. cache_batch_idx=cache_idxs,
  167. causal=causal,
  168. gqa_parallel=False,
  169. num_splits=1
  170. )
  171. max_diff = (out0 - out1).abs().max().item()
  172. mean_diff = (out0 - out1).abs().mean().item()
  173. # print (f"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}")
  174. # print (f"splits {num_splits}, time {t:.2f}")
  175. if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
  176. print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")
  177. if t < fa3_fastest_splitk_time:
  178. fa3_fastest_splitk_time = t
  179. fa3_fastest_num_splits = num_splits
  180. fa3_fastest_num_splits_gqa = 0
  181. fa3_fastest_splitk_time_gqa = float("inf")
  182. for num_splits in range(1, max_splits):
  183. t = timeit(
  184. flash_attn_interface.flash_attn_with_kvcache,
  185. q=q,
  186. k_cache=k_cache,
  187. v_cache=v_cache,
  188. cache_seqlens=cache_seqlens,
  189. cache_batch_idx=cache_idxs,
  190. causal=causal,
  191. gqa_parallel=True,
  192. num_splits=num_splits
  193. ) * 1000. * 1000.
  194. out0 = flash_attn_interface.flash_attn_with_kvcache(
  195. q=q,
  196. k_cache=k_cache,
  197. v_cache=v_cache,
  198. cache_seqlens=cache_seqlens,
  199. cache_batch_idx=cache_idxs,
  200. causal=causal,
  201. gqa_parallel=True,
  202. num_splits=num_splits
  203. )
  204. out1 = flash_attn_interface.flash_attn_with_kvcache(
  205. q=q,
  206. k_cache=k_cache,
  207. v_cache=v_cache,
  208. cache_seqlens=cache_seqlens,
  209. cache_batch_idx=cache_idxs,
  210. causal=causal,
  211. gqa_parallel=True,
  212. num_splits=1
  213. )
  214. max_diff = (out0 - out1).abs().max().item()
  215. mean_diff = (out0 - out1).abs().mean().item()
  216. # print (f"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}")
  217. # print (f"gqa splits {num_splits}, time {t:.2f}")
  218. if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
  219. print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")
  220. if t < fa3_fastest_splitk_time_gqa:
  221. fa3_fastest_splitk_time_gqa = t
  222. fa3_fastest_num_splits_gqa = num_splits
  223. efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms
  224. heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa
  225. # remeasure to smooth anomalies
  226. if heuristic_ratio > 1.1:
  227. fa3_time_gqa_heuristic = timeit(
  228. flash_attn_interface.flash_attn_with_kvcache,
  229. q=q,
  230. k_cache=k_cache,
  231. v_cache=v_cache,
  232. cache_seqlens=cache_seqlens,
  233. cache_batch_idx=cache_idxs,
  234. causal=causal,
  235. gqa_parallel=True,
  236. # num_splits=num_splits_select,
  237. # num_splits=1,
  238. num_splits=0,
  239. max_seqlen_k_hint=context_seqlen
  240. ) * 1000. * 1000.
  241. fa3_fastest_splitk_time_gqa = timeit(
  242. flash_attn_interface.flash_attn_with_kvcache,
  243. q=q,
  244. k_cache=k_cache,
  245. v_cache=v_cache,
  246. cache_seqlens=cache_seqlens,
  247. cache_batch_idx=cache_idxs,
  248. causal=causal,
  249. gqa_parallel=True,
  250. num_splits=fa3_fastest_num_splits_gqa
  251. ) * 1000. * 1000.
  252. if check_all_splits is True:
  253. print(
  254. f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, "
  255. f"FA2:{fa2_time_heuristic:.2f}, "
  256. # f"FA2 MANUAL:{fastest_splitk_time:.2f}, "
  257. # f"FA2 NUM SPLITS:{fastest_splitk}, "
  258. # f"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, "
  259. # f"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, "
  260. # f"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, "
  261. f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, "
  262. f"FA3:{fa3_time_gqa_heuristic:.2f}, "
  263. # f"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, "
  264. # f"FA2 NUM SPLITS:{fastest_splitk}, "
  265. # f"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, "
  266. f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, "
  267. # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
  268. f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
  269. f"EFF:{efficiency:.2f}, "
  270. f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
  271. )
  272. if check_all_splits is False:
  273. print(
  274. f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}"
  275. f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}"
  276. f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}"
  277. f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}"
  278. )
# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()