# test_attn_kvcache.py

import pytest
from einops import rearrange, repeat
import torch
import flash_attn
import flash_attn_interface
import itertools
import math
import time


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )

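
# Minimal CPU-only sketch of how construct_local_mask behaves (illustrative addition, not
# part of the original suite; the expected values are an assumption based on the
# bottom-right-aligned causal convention used above). With seqlen_q=2, seqlen_k=4 and
# causal masking (window_size=(-1, 0)), query row i may attend to keys 0..(seqlen_k - seqlen_q + i),
# so only the last key of the first query row is masked out.
def test_construct_local_mask_causal_alignment():
    mask = construct_local_mask(seqlen_q=2, seqlen_k=4, window_size=(-1, 0))
    expected = torch.tensor(
        [
            [False, False, False, True],   # row 0 attends to keys 0..2
            [False, False, False, False],  # row 1 attends to keys 0..3
        ]
    )
    assert torch.equal(mask, expected)
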
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)

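
# Minimal CPU-only sketch checking attention_ref against PyTorch's built-in
# scaled_dot_product_attention on tiny fp32 inputs (illustrative addition, not part of
# the original suite). With seqlen_q == seqlen_k and no GQA, the bottom-right causal
# alignment used above coincides with is_causal=True.
def test_attention_ref_matches_sdpa():
    torch.manual_seed(0)
    batch, seqlen, nheads, headdim = 2, 5, 3, 8
    q = torch.randn(batch, seqlen, nheads, headdim)
    k = torch.randn(batch, seqlen, nheads, headdim)
    v = torch.randn(batch, seqlen, nheads, headdim)
    out, _ = attention_ref(q, k, v, causal=True)
    # SDPA expects (batch, nheads, seqlen, headdim) layout.
    out_sdpa = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert torch.allclose(out, out_sdpa, atol=1e-5)
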
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("num_requests", [1, 4])
@pytest.mark.parametrize("query_seqlen", [1, 8, 120])
@pytest.mark.parametrize("context_seqlen", [1024, 3131, 4224])
@pytest.mark.parametrize("headdim", [64, 128, 256])
@pytest.mark.parametrize("gqa_parallel", [False, True])
@pytest.mark.parametrize(
    "nheads_kv, gqa_ratio",
    [
        (1, 1),
        (2, 5),
        (3, 3),
        (1, 32),
        (5, 7),
        (8, 1),
        (1, 16),
        (12, 4),
        (8, 2),
    ],
)
def test_flash_attn_kvcache_nosplit(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel):
    device = "cuda"
    num_caches = num_requests
    cache_seqlen = context_seqlen
    nheads_q = nheads_kv * gqa_ratio
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16)
    # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
    cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda")
    torch.cuda.synchronize()
    out_ref, _ = attention_ref(
        q,
        k_cache,
        v_cache,
        causal=causal,
    )
    out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens,
        # cache_batch_idx=cache_idxs,
        causal=causal,
        num_splits=1,
        return_softmax_lse=True,
        gqa_parallel=gqa_parallel,
    )
    torch.cuda.synchronize()
    assert (out_ref - out_fa3).abs().max().item() <= 4e-3
    assert (out_ref - out_fa3).abs().mean().item() <= 2e-4

@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("num_requests", [1, 3])
@pytest.mark.parametrize("query_seqlen", [1, 8, 120])
@pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555])
@pytest.mark.parametrize("headdim", [64, 128, 256])
@pytest.mark.parametrize("gqa_parallel", [True, False])
@pytest.mark.parametrize(
    "nheads_kv, gqa_ratio",
    [
        (1, 1),
        (2, 5),
        (3, 3),
        (1, 32),
        (5, 7),
        (8, 1),
        (1, 16),
        (12, 4),
        (8, 2),
    ],
)
def test_flash_attn_kvcache_nosplit_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel):
    device = "cuda"
    num_caches = num_requests
    cache_seqlen = context_seqlen
    nheads_q = nheads_kv * gqa_ratio
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16)
    q = q.to(torch.float8_e4m3fn)
    k_cache = k_cache.to(torch.float8_e4m3fn)
    v_cache = v_cache.to(torch.float8_e4m3fn)
    # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
    cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda")
    torch.cuda.synchronize()
    out_ref, _ = attention_ref(
        q,
        k_cache,
        v_cache,
        causal=causal,
    )
    descale_q = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    descale_k = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    descale_v = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens,
        # cache_batch_idx=cache_idxs,
        causal=causal,
        num_splits=1,
        return_softmax_lse=True,
        gqa_parallel=gqa_parallel,
        descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
    )
    torch.cuda.synchronize()
    assert (out_ref - out_fa3).abs().max().item() <= 4e-2
    assert (out_ref - out_fa3).abs().mean().item() <= 2e-3

@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_heuristic_only", [True])
# @pytest.mark.parametrize("use_heuristic_only", [False])
@pytest.mark.parametrize("causal", [True, False])
# @pytest.mark.parametrize("num_requests", [1, 4, 16])
@pytest.mark.parametrize("num_requests", [1, 3])
# @pytest.mark.parametrize("query_seqlen", [1, 16, 32, 128])
@pytest.mark.parametrize("query_seqlen", [1, 8, 25])
# @pytest.mark.parametrize("context_seqlen", [4096, 16384, 65536])
@pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555])
@pytest.mark.parametrize("headdim", [64, 128, 256])
@pytest.mark.parametrize("cache_seqlen_rand", [True, False])
@pytest.mark.parametrize("gqa_parallel", [True, False])
@pytest.mark.parametrize(
    "nheads_kv, gqa_ratio",
    [
        (1, 1),
        (4, 1),
        (2, 2),
        (3, 3),
        (4, 4),
        (2, 5),
        (3, 9),
        (1, 16),
        (1, 32),
    ],
)
def test_flash_attn_kvcache_output(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype):
    device = "cuda"
    num_caches = 16
    if context_seqlen <= 65536:
        cache_seqlen = 65536
    else:
        cache_seqlen = context_seqlen
    nheads_q = nheads_kv * gqa_ratio
    if use_heuristic_only:
        max_splits = 1
    else:
        max_splits = 128
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16)
    q = q.to(dtype)
    k_cache = k_cache.to(dtype)
    v_cache = v_cache.to(dtype)
    cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
    cache_seqlens = (
        torch.randint(1, context_seqlen - 1, (num_requests,), dtype=torch.int32).to(device)
        if cache_seqlen_rand
        else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda")
    )
    torch.cuda.synchronize()
    out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_idxs,
        causal=causal,
        num_splits=1,
        return_softmax_lse=True,
        gqa_parallel=False,
    )
    # i=0 case is with num splits heuristic
    for i in range(0, max_splits + 1):
        out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens,
            cache_batch_idx=cache_idxs,
            causal=causal,
            num_splits=i,
            return_softmax_lse=True,
            gqa_parallel=gqa_parallel,
            max_seqlen_k_hint=context_seqlen,
        )
        torch.cuda.synchronize()
        print("output-ref", i, out_ref)
        print("output-fa3", i, out_fa3)
        print("output-max-diff", i, context_seqlen, (out_ref - out_fa3).abs().max().item())
        print("output-mean-diff", i, context_seqlen, (out_ref - out_fa3).abs().mean().item())
        print("lse-max-diff", i, context_seqlen, (lse_ref - lse_fa3).abs().max().item())
        print("lse-mean-diff", i, context_seqlen, (lse_ref - lse_fa3).abs().mean().item())
        if cache_seqlen_rand:
            assert (out_ref - out_fa3).abs().max().item() <= 1e-2
            assert (out_ref - out_fa3).abs().mean().item() <= 1e-3
        else:
            assert (out_ref - out_fa3).abs().max().item() <= 2e-3
            assert (out_ref - out_fa3).abs().mean().item() <= 1e-4
        lse_max_ref = lse_ref.abs().max().item()
        lse_mean_ref = lse_ref.abs().mean().item()
        lse_max_fa3 = lse_fa3.abs().max().item()
        lse_mean_fa3 = lse_fa3.abs().mean().item()
        lse_max_diff = (lse_ref - lse_fa3).abs().max().item()
        lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item()
        assert (lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3
        assert (lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4

@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("use_heuristic_only", [True])
# @pytest.mark.parametrize("use_heuristic_only", [False])
@pytest.mark.parametrize("causal", [True, False])
# @pytest.mark.parametrize("num_requests", [1, 4, 16])
@pytest.mark.parametrize("num_requests", [1, 3])
# @pytest.mark.parametrize("query_seqlen", [1, 16, 32, 128])
@pytest.mark.parametrize("query_seqlen", [1, 8, 25])
# @pytest.mark.parametrize("context_seqlen", [4096, 16384, 65536])
@pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555])
@pytest.mark.parametrize("headdim", [64, 128, 256])
@pytest.mark.parametrize("cache_seqlen_rand", [True, False])
@pytest.mark.parametrize("gqa_parallel", [True, False])
@pytest.mark.parametrize(
    "nheads_kv, gqa_ratio",
    [
        (1, 1),
        (4, 1),
        (2, 2),
        (3, 3),
        (4, 4),
        (2, 5),
        (3, 9),
        (1, 16),
        (1, 32),
    ],
)
def test_flash_attn_kvcache_output_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype):
    device = "cuda"
    num_caches = 16
    if context_seqlen <= 65536:
        cache_seqlen = 65536
    else:
        cache_seqlen = context_seqlen
    nheads_q = nheads_kv * gqa_ratio
    if use_heuristic_only:
        max_splits = 1
    else:
        max_splits = 128
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16
    )
    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16)
    q = q.to(dtype)
    k_cache = k_cache.to(dtype)
    v_cache = v_cache.to(dtype)
    cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
    cache_seqlens = (
        torch.randint(1, context_seqlen - 1, (num_requests,), dtype=torch.int32).to(device)
        if cache_seqlen_rand
        else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda")
    )
    torch.cuda.synchronize()
    descale_q = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    descale_k = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    descale_v = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_idxs,
        causal=causal,
        num_splits=1,
        return_softmax_lse=True,
        gqa_parallel=False,
        descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
    )
    # i=0 case is with num splits heuristic
    for i in range(0, max_splits + 1):
        out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens,
            cache_batch_idx=cache_idxs,
            causal=causal,
            num_splits=i,
            return_softmax_lse=True,
            gqa_parallel=gqa_parallel,
            max_seqlen_k_hint=context_seqlen,
            descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
        )
        torch.cuda.synchronize()
        print("output-ref", i, out_ref)
        print("output-fa3", i, out_fa3)
        print("output-max-diff", i, context_seqlen, (out_ref - out_fa3).abs().max().item())
        print("output-mean-diff", i, context_seqlen, (out_ref - out_fa3).abs().mean().item())
        print("lse-max-diff", i, context_seqlen, (lse_ref - lse_fa3).abs().max().item())
        print("lse-mean-diff", i, context_seqlen, (lse_ref - lse_fa3).abs().mean().item())
        if cache_seqlen_rand:
            assert (out_ref - out_fa3).abs().max().item() <= 1e-1
            assert (out_ref - out_fa3).abs().mean().item() <= 1e-2
        else:
            assert (out_ref - out_fa3).abs().max().item() <= 2e-2
            assert (out_ref - out_fa3).abs().mean().item() <= 2e-3
        lse_max_ref = lse_ref.abs().max().item()
        lse_mean_ref = lse_ref.abs().mean().item()
        lse_max_fa3 = lse_fa3.abs().max().item()
        lse_mean_fa3 = lse_fa3.abs().mean().item()
        lse_max_diff = (lse_ref - lse_fa3).abs().max().item()
        lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item()
        assert (lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3
        assert (lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4

if __name__ == "__main__":
    # Run the tests in this file directly with pytest.
    pytest.main([__file__])