test_kvcache.py

import torch
# from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
import flash_attn_interface as fa3
import flash_attn as fa2
import torch.utils.benchmark as benchmark
import time
import argparse
import math

parser = argparse.ArgumentParser(description='Benchmark FA3 vs FA2 flash_attn_with_kvcache.')
parser.add_argument('--causal', action='store_true')
parser.add_argument('--splits', type=int, default=1)
parser.add_argument('--repeats', type=int, default=10)
parser.add_argument('--validate', action='store_true')
parser.add_argument('--gqa', action='store_true')
args = parser.parse_args()
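
# Example invocations (assuming a CUDA device with both FA2 (`flash_attn`)
# and FA3 (`flash_attn_interface`) installed):
#   python test_kvcache.py --validate
#   python test_kvcache.py --splits 4 --repeats 100 --gqa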


def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(**kwinputs)',
        globals={'fn': fn, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(desc, m)
    return t, m


def benchmark_fa_kv(fn, repeats=10, *args, **kwargs):
    """Time `fn` with manual CUDA synchronization; returns mean seconds per call."""
    # warmup
    for _ in range(5):
        fn(*args, **kwargs)
    niters = repeats
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(niters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    end = time.time()
    return (end - start) / niters
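
# Hypothetical usage sketch (not part of the original script): positional
# arguments after `repeats` are forwarded straight to `fn`.
#   a = torch.randn(1024, 1024, device="cuda")
#   mean_s = benchmark_fa_kv(torch.mm, 20, a, a)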


def main():
    # *SAMPLE CONFIG*
    # Model arch params:
    nheads_q = 64
    nheads_kv = 8
    headdim = 128
    # dtype = torch.bfloat16
    dtype = torch.float16

    # Cache settings:
    num_caches = 8
    cache_seqlen = 1024 * 16

    # Batching settings
    ntokens = 1024
    max_queries_per_batch = 4
    small_request_ntokens = 16

    # Input settings
    query_seqlens = [900, 12, 1]
    num_queries = len(query_seqlens)
    # Need to add empty queries to fill out `max_queries_per_batch`
    num_padding_queries = max_queries_per_batch - num_queries
    context_seqlens = [4096, 5120 * 2, 6145 * 2]
    # context_seqlens = [4096, 5120*2, 6152*2]

    # Validation
    assert sum(query_seqlens) <= ntokens
    assert all(s < small_request_ntokens for s in query_seqlens[1:])
    assert num_queries <= max_queries_per_batch
    assert all(s < cache_seqlen for s in context_seqlens)
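
    # Note: with nheads_q = 64 and nheads_kv = 8, each KV head serves 8 query
    # heads (GQA ratio 8), which is presumably what the --gqa flag exercises.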

    torch.manual_seed(5434)

    # Allocate some tensors
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
    )
    q_buf_large = torch.randn(
        (1, ntokens, nheads_q, headdim), device="cuda", dtype=dtype
    )
    cache_seqlen_large = torch.tensor(
        [context_seqlens[0]], dtype=torch.int32, device="cuda"
    )
    cache_idx_large = torch.tensor([1], dtype=torch.int32, device="cuda")
    q_buf_small = torch.randn(
        (max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim),
        device="cuda",
        dtype=dtype,
    )
    cache_seqlens_small = torch.tensor(
        context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device="cuda"
    )
    cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[
        : max_queries_per_batch - 1
    ]
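
    # Shape summary: q_buf_large is one prefill-style request (batch 1,
    # ntokens queries) reading cache entry 1; q_buf_small is a batch of
    # max_queries_per_batch - 1 decode-style requests, each reading its own
    # randomly chosen cache entry via cache_batch_idx.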

    if args.validate:
        # Call flash attn (FA3)
        # First for the single full-sized query
        out0, lse0 = fa3.flash_attn_with_kvcache(
            q=q_buf_large,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlen_large,
            cache_batch_idx=cache_idx_large,
            causal=bool(args.causal),
            num_splits=args.splits,
            return_softmax_lse=True,
            # num_splits=1
        )
        # Then for the n-1 small queries, forced to a single split
        out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=1,
            gqa_decoding=bool(args.gqa),
            return_softmax_lse=True,
        )
        # And again for the n-1 small queries with the requested split count
        out1, lse1 = fa3.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=args.splits,
            gqa_decoding=bool(args.gqa),
            return_softmax_lse=True,
        )
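        # (Hypothetical extra check, not in the original script: out1_split1 is
        # otherwise unused, and split/non-split FA3 should agree numerically.)
        # print('fa3 split-vs-nosplit diff-max',
        #       (out1 - out1_split1).abs().max().item())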

        # Repeat with FA2 as the reference implementation
        # First for the single full-sized query
        out2 = fa2.flash_attn_with_kvcache(
            q=q_buf_large,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlen_large,
            cache_batch_idx=cache_idx_large,
            causal=bool(args.causal),
            num_splits=args.splits,
        )
        print('big')
        print('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small)
        print('diff-mean', (out0 - out2).abs().mean().item())
        # Second for the n-1 small queries
        out3, lse_fa2 = fa2.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=args.splits,
            return_softmax_lse=True,
            # num_splits=1
        )
        print('small')  # , out1
        print('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape)
        print('lse-dif-max', (lse1 - lse_fa2).abs().max().item())
        print('diff-max', (out1 - out3).abs().max().item())
        print('diff-mean', (out1 - out3).abs().mean().item())

    print('fa3', args.repeats)
    time_fa3_big = benchmark_fa_kv(
        fa3.flash_attn_with_kvcache, repeats=args.repeats,
        q=q_buf_large,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlen_large,
        cache_batch_idx=cache_idx_large,
        causal=bool(args.causal),
        num_splits=args.splits,
    )
    time_fa3_small = benchmark_fa_kv(
        fa3.flash_attn_with_kvcache, repeats=args.repeats,
        q=q_buf_small,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens_small,
        cache_batch_idx=cache_idxs_small,
        causal=bool(args.causal),
        num_splits=args.splits,
    )

    print('fa2')
    time_fa2_big = benchmark_fa_kv(
        fa2.flash_attn_with_kvcache, repeats=args.repeats,
        q=q_buf_large,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlen_large,
        cache_batch_idx=cache_idx_large,
        causal=bool(args.causal),
        num_splits=args.splits,
    )
    time_fa2_small = benchmark_fa_kv(
        fa2.flash_attn_with_kvcache, repeats=args.repeats,
        q=q_buf_small,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens_small,
        cache_batch_idx=cache_idxs_small,
        causal=bool(args.causal),
        num_splits=args.splits,
    )

    # Times are reported in microseconds; ratio < 1 means FA3 beat FA2.
    print('big (split, fa3, fa2, ratio):', args.splits,
          time_fa3_big * 1000000, time_fa2_big * 1000000,
          time_fa3_big / time_fa2_big)
    print('small (split, fa3, fa2, ratio):', args.splits,
          time_fa3_small * 1000000, time_fa2_small * 1000000,
          time_fa3_small / time_fa2_small)


if __name__ == "__main__":
    main()