import itertools
import math
import time

import torch
import torch.utils.benchmark as benchmark

import flash_attn
import flash_attn_interface
def round_up_to_power_of_2(x):
    """Return the smallest power of two >= x; any x <= 1 maps to 1."""
    if x <= 1:
        return 1
    # (x - 1).bit_length() is exactly the exponent of the next power of two,
    # and is already correct when x itself is a power of two.
    return 1 << (x - 1).bit_length()
def timeit(fn, *args, **kwargs):
    """Return the mean wall-clock time, in seconds, of ``fn(*args, **kwargs)``.

    Synchronizes the CUDA device, runs 5 untimed warmup calls (so kernel
    compilation/autotuning does not pollute the measurement), then measures
    20 timed runs with ``torch.utils.benchmark.Timer``, which handles CUDA
    synchronization around the timed region itself.

    Requires a CUDA device to be available.
    """
    torch.cuda.synchronize()

    # Warmup: excluded from the measurement.
    for _ in range(5):
        fn(*args, **kwargs)

    timer = benchmark.Timer(
        stmt='fn(*args, **kwargs)',
        globals={'fn': fn, 'args': args, 'kwargs': kwargs}
    )

    # Average over 20 runs.
    measurement = timer.timeit(20)
    # measurement = timer.blocked_autorange(min_run_time=1)
    return measurement.mean
def main():
    """Benchmark decode attention (``flash_attn_with_kvcache``): FA2 vs FA3.

    For every model config and every (context_seqlen, num_requests,
    query_seqlen) combination, measures:

      * FA2 with its built-in split-k heuristic,
      * FA3 with ``num_splits=1`` and no GQA parallelism (reference point),
      * FA3 with ``gqa_parallel=True`` and its split heuristic
        (``num_splits=0``).

    With ``check_all_splits=True`` it additionally sweeps ``num_splits``
    exhaustively (with and without ``gqa_parallel``), validates each split
    count's output against the single-split result, and reports how close
    the FA3 heuristic comes to the best manual split choice.

    Requires a CUDA device plus the ``flash_attn`` and
    ``flash_attn_interface`` packages; results are printed to stdout.
    """
    num_sms = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count

    max_splits = 129          # num_splits sweep covers [1, 129)
    check_all_splits = False  # True: exhaustive split sweep + numerics check
    causal = True
    # causal = False
    # dtype = torch.float16
    dtype = torch.bfloat16

    torch.manual_seed(42)

    # (model name, nheads_q, nheads_kv, headdim)
    model_configs = [
        # ("Gemma-2-2B", 8, 4, 256),
        # ("Gemma-2-9B", 16, 8, 256),
        # ("Gemma-2-27B", 32, 16, 128),
        # ("Qwen-2.5-0.5B", 14, 2, 64),
        # ("Qwen-2.5-1.5B", 12, 2, 128),
        # ("Qwen-2.5-7B", 28, 4, 128),
        # ("Llama-3.1-8B", 32, 8, 128),
        ("Llama-3.1-70B", 64, 8, 128),
        # ("Llama-3.1-405B", 128, 8, 128),
        # ("Llama-3.2-1B", 32, 8, 64),
        # ("Llama-3.2-3B", 24, 8, 128),
        # ("Nemotron-4-15B", 48, 8, 128),
    ]

    all_batch_configs = []
    all_batch_configs.extend(itertools.product(
        # [1024, 2048, 4096, 8192, 16384, 32768, 131072],  # context_seqlen
        [4096, 16384, 65536],  # context_seqlen
        # [i for i in range(1, num_sms + 1)],  # num_requests
        [1, 4, 8, 16],         # num_requests
        [1, 4, 8, 16],         # query_seqlen
    ))

    # Allocate the KV cache once, sized for the largest batch/context swept.
    num_caches = max(reqs for _, reqs, _ in all_batch_configs)
    cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)

    for model_name, nheads_q, nheads_kv, headdim in model_configs:
        k_cache = torch.randn(
            (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
        )
        v_cache = torch.randn(
            (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
        )
        print(f"***{model_name}***")
        print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}")

        if not check_all_splits:
            print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")

        for context_seqlen, num_requests, query_seqlen in all_batch_configs:
            # Bytes touched per call; the factor 4 presumably covers
            # 2 bytes/element (fp16/bf16) times two tensors (K + V, and for
            # q: read-q + write-out) — TODO confirm.
            bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)
            bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)

            # Estimate FA3's work-tile count for the EFF metric below.
            blockH = round_up_to_power_of_2(nheads_q // nheads_kv)
            blockM = 128  # true for hdim 128 causal and hdim 64
            blockM_div_H = blockM // blockH
            num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen / blockM_div_H)

            q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype)
            # Each request attends to a distinct random slot of the shared cache.
            cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
            cache_seqlens = torch.tensor(
                [context_seqlen] * num_requests, dtype=torch.int32, device="cuda"
            )

            # All times below are converted from seconds to microseconds.
            fa2_time_heuristic = timeit(
                flash_attn.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
            ) * 1000. * 1000.

            # Measured for reference; only reported if the corresponding
            # field is re-enabled in the summary print below.
            fa3_time_one_split = timeit(
                flash_attn_interface.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
                gqa_parallel=False,
                num_splits=1,
            ) * 1000. * 1000.

            fa3_time_gqa_heuristic = timeit(
                flash_attn_interface.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
                gqa_parallel=True,
                num_splits=0,  # 0 = let FA3's heuristic pick the split count
                max_seqlen_k_hint=context_seqlen
            ) * 1000. * 1000.

            if check_all_splits:
                # --- Sweep num_splits without GQA parallelism. ---
                fa3_fastest_num_splits = 0
                fa3_fastest_splitk_time = float("inf")
                for num_splits in range(1, max_splits):
                    t = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=False,
                        num_splits=num_splits
                    ) * 1000. * 1000.
                    # Validate this split count against the single-split output.
                    out0 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=False,
                        num_splits=num_splits
                    )
                    out1 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=False,
                        num_splits=1
                    )
                    max_diff = (out0 - out1).abs().max().item()
                    mean_diff = (out0 - out1).abs().mean().item()
                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
                        print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")
                    if t < fa3_fastest_splitk_time:
                        fa3_fastest_splitk_time = t
                        fa3_fastest_num_splits = num_splits

                # --- Sweep num_splits with GQA parallelism. ---
                fa3_fastest_num_splits_gqa = 0
                fa3_fastest_splitk_time_gqa = float("inf")
                for num_splits in range(1, max_splits):
                    t = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=num_splits
                    ) * 1000. * 1000.
                    out0 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=num_splits
                    )
                    out1 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=1
                    )
                    max_diff = (out0 - out1).abs().max().item()
                    mean_diff = (out0 - out1).abs().mean().item()
                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
                        print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")
                    if t < fa3_fastest_splitk_time_gqa:
                        fa3_fastest_splitk_time_gqa = t
                        fa3_fastest_num_splits_gqa = num_splits

                # EFF = total split work tiles per SM at the best split count.
                efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa) / num_sms
                heuristic_ratio = fa3_time_gqa_heuristic / fa3_fastest_splitk_time_gqa
                # Re-measure when the heuristic looks >10% worse than the best
                # manual split, to smooth out one-off timing anomalies.
                if heuristic_ratio > 1.1:
                    fa3_time_gqa_heuristic = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=0,
                        max_seqlen_k_hint=context_seqlen
                    ) * 1000. * 1000.
                    fa3_fastest_splitk_time_gqa = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        gqa_parallel=True,
                        num_splits=fa3_fastest_num_splits_gqa
                    ) * 1000. * 1000.

                print(
                    f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, "
                    f"FA2:{fa2_time_heuristic:.2f}, "
                    f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, "
                    f"FA3:{fa3_time_gqa_heuristic:.2f}, "
                    f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, "
                    f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
                    f"EFF:{efficiency:.2f}, "
                    f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
                )
            else:
                # bytes / microseconds * 1e-3 == GB/s.
                print(
                    f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}"
                    f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}"
                    f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}"
                    f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}"
                )
if __name__ == "__main__":
    main()