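"""Micro-benchmark of flash_attn_with_kvcache: FlashAttention-3 vs FlashAttention-2.

Times one full-sized query and a padded batch of small decode-style queries
against preallocated KV caches, optionally validating FA3 outputs against FA2
first. Example invocation (the script name is illustrative):

    python benchmark_kvcache.py --validate --causal --splits 4
"""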
import argparse
import time

import torch
import torch.utils.benchmark as benchmark

import flash_attn as fa2
import flash_attn_interface as fa3

parser = argparse.ArgumentParser(description='Benchmark flash_attn_with_kvcache: FA3 vs FA2.')
parser.add_argument('--causal', action='store_true', help='use causal masking')
parser.add_argument('--splits', type=int, default=1, help='num_splits to pass to the kernels')
parser.add_argument('--repeats', type=int, default=10, help='timed iterations per measurement')
parser.add_argument('--validate', action='store_true', help='check FA3 outputs against FA2 before timing')
parser.add_argument('--gqa', action='store_true', help='enable gqa_decoding in the FA3 calls')
args = parser.parse_args()
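
# NOTE on --splits: assuming the usual flash-attn semantics for num_splits,
# 1 runs the kernels without split-KV, values > 1 force that many KV splits,
# and 0 lets the library's own heuristic choose.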


def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):
    """Time the forward pass of an arbitrary function with torch.utils.benchmark.

    Kept for reference; main() uses the wall-clock helper below instead.
    """
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(**kwinputs)',
        globals={'fn': fn, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(desc, m)
    return t, m


def benchmark_fa_kv(fn, repeats=10, *args, **kwargs):
    """Return mean wall-clock seconds per call of fn(*args, **kwargs)."""
    # Warm up so one-time costs don't pollute the measurement.
    for _ in range(5):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeats):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    end = time.time()
    return (end - start) / repeats


def main():
    # *SAMPLE CONFIG*
    # Model arch params:
    nheads_q = 64
    nheads_kv = 8
    headdim = 128
    # dtype = torch.bfloat16
    dtype = torch.float16

    # Cache settings:
    num_caches = 8
    cache_seqlen = 1024 * 16

    # Batching settings
    ntokens = 1024
    max_queries_per_batch = 4
    small_request_ntokens = 16

    # Input settings
    query_seqlens = [900, 12, 1]
    num_queries = len(query_seqlens)
    # Need to add empty queries to fill out `max_queries_per_batch`.
    num_padding_queries = max_queries_per_batch - num_queries
    context_seqlens = [4096, 5120 * 2, 6145 * 2]
    # context_seqlens = [4096, 5120 * 2, 6152 * 2]

    # Validate the config before allocating anything.
    assert sum(query_seqlens) <= ntokens
    assert all(s < small_request_ntokens for s in query_seqlens[1:])
    assert num_queries <= max_queries_per_batch
    assert all(s < cache_seqlen for s in context_seqlens)

    torch.manual_seed(5434)

    # Allocate the KV caches and query buffers.
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
    )
    # One full-sized query against a single cache entry...
    q_buf_large = torch.randn(
        (1, ntokens, nheads_q, headdim), device="cuda", dtype=dtype
    )
    cache_seqlen_large = torch.tensor(
        [context_seqlens[0]], dtype=torch.int32, device="cuda"
    )
    cache_idx_large = torch.tensor([1], dtype=torch.int32, device="cuda")
    # ...and a padded batch of small queries, each hitting its own cache entry.
    q_buf_small = torch.randn(
        (max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim),
        device="cuda",
        dtype=dtype,
    )
    cache_seqlens_small = torch.tensor(
        context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device="cuda"
    )
    cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[
        : max_queries_per_batch - 1
    ]
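
    # Sanity checks (--validate): FA2 serves as the reference implementation.
    # FA3 is additionally run on the small batch with num_splits=1, so a bad
    # split-KV result can be told apart from a generally wrong kernel.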
    if args.validate:
        # FA3, single full-sized query.
        out0, lse0 = fa3.flash_attn_with_kvcache(
            q=q_buf_large,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlen_large,
            cache_batch_idx=cache_idx_large,
            causal=bool(args.causal),
            num_splits=args.splits,
            return_softmax_lse=True,
        )
        # FA3, n-1 small queries, forced num_splits=1 (debug reference;
        # not used in the comparisons below).
        out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=1,
            gqa_decoding=bool(args.gqa),
            return_softmax_lse=True,
        )
        # FA3, n-1 small queries with the requested number of splits.
        out1, lse1 = fa3.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=args.splits,
            gqa_decoding=bool(args.gqa),
            return_softmax_lse=True,
        )
        # FA2 reference, single full-sized query.
        out2 = fa2.flash_attn_with_kvcache(
            q=q_buf_large,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlen_large,
            cache_batch_idx=cache_idx_large,
            causal=bool(args.causal),
            num_splits=args.splits,
        )
        print('big')
        print('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small)
        print('diff-mean', (out0 - out2).abs().mean().item())
        # FA2 reference, n-1 small queries.
        out3, lse_fa2 = fa2.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=args.splits,
            return_softmax_lse=True,
        )
        print('small')
        print('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape)
        print('lse-diff-max', (lse1 - lse_fa2).abs().max().item())
        print('diff-max', (out1 - out3).abs().max().item())
        print('diff-mean', (out1 - out3).abs().mean().item())

    print('fa3', args.repeats)
    time_fa3_big = benchmark_fa_kv(
        fa3.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_large,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlen_large,
        cache_batch_idx=cache_idx_large,
        causal=bool(args.causal),
        num_splits=args.splits,
    )
    time_fa3_small = benchmark_fa_kv(
        fa3.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_small,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens_small,
        cache_batch_idx=cache_idxs_small,
        causal=bool(args.causal),
        num_splits=args.splits,
    )
    print('fa2')
    time_fa2_big = benchmark_fa_kv(
        fa2.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_large,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlen_large,
        cache_batch_idx=cache_idx_large,
        causal=bool(args.causal),
        num_splits=args.splits,
    )
    time_fa2_small = benchmark_fa_kv(
        fa2.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_small,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens_small,
        cache_batch_idx=cache_idxs_small,
        causal=bool(args.causal),
        num_splits=args.splits,
    )
    # Times are in microseconds; a ratio below 1 means FA3 is faster.
    print('big   (splits, fa3 us, fa2 us, fa3/fa2):', args.splits,
          time_fa3_big * 1e6, time_fa2_big * 1e6, time_fa3_big / time_fa2_big)
    print('small (splits, fa3 us, fa2 us, fa3/fa2):', args.splits,
          time_fa3_small * 1e6, time_fa2_small * 1e6, time_fa3_small / time_fa2_small)


if __name__ == "__main__":
    main()