# test_flash_attn.py

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, _flash_attn_forward
from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref

ABS_TOL = 5e-3
REL_TOL = 1e-1
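

# Debug helper: print every element of `out` whose absolute or relative
# difference from `out_ref` exceeds the tolerances above.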
def print_diffs(out, out_ref):
    out_1d = out.flatten()
    out_ref_1d = out_ref.flatten()
    for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)):
        diff = e_o - e_o_ref
        abs_diff = abs(diff)
        abs_ref = abs(e_o_ref + 1e-5)
        relative_diff = abs_diff / abs_ref
        if abs_diff > ABS_TOL or relative_diff > REL_TOL:
            print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
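

# Dense (non-varlen) test: compare flash_attn_func (or _flash_attn_forward for fp8)
# against the fp32 attention_ref implementation, for MHA/MQA/GQA, causal and
# non-causal masking, several head dims, and a range of sequence-length pairs.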
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("descale", [1.0])
# @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        # (257, 1),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype, descale
):
    device = "cuda"
    # fp8 inputs are first generated in fp16 (dtype_init) and cast to fp8 below,
    # since torch.randn cannot produce float8 tensors directly.
    if dtype == torch.float8_e4m3fn:
        dtype_init = torch.float16
    else:
        dtype_init = dtype
    print(dtype)
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 4
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # nheads_kv = 2
    # batch_size = 9
    # nheads = 6
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
    q = q.to(dtype)
    k = k.to(dtype)
    v = v.to(dtype)

    softmax_scale = q.shape[-1] ** (-0.5)
    descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
    descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
    descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
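
    # fp16/bf16 goes through the public flash_attn_func API; fp8 goes through the
    # low-level _flash_attn_forward, which takes per-tensor descale factors.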
    if dtype != torch.float8_e4m3fn:
        out, lse = flash_attn_func(q, k, v, causal=causal, deterministic=deterministic)
    else:
        out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
            q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
        )
    q = q.to(dtype_init)
    k = k.to(dtype_init)
    v = v.to(dtype_init)
    if dtype == torch.float8_e4m3fn:
        # Apply the same descale factors to the reference inputs, so the fp32
        # reference sees the values the fp8 kernel effectively computes with.
        descale_q = descale_q.to(dtype_init)
        descale_k = descale_k.to(dtype_init)
        descale_v = descale_v.to(dtype_init)
        q = q * descale_q
        k = k * descale_k
        v = v * descale_v
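
    # Two references: attention_ref upcasts to fp32 (ground truth), while the second
    # call (upcast=False, reorder_ops=True) mimics a plain PyTorch implementation at
    # the same precision and is used to scale the acceptable error bound.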
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # exp_sum = s_tmp.sum(-1)
    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
    # lse_ref = torch.logsumexp(qk, dim=-1)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # if not causal:
    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
    # breakpoint()
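
    # The backward pass is only exercised for head dims <= 128 and non-fp8 dtypes.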
    if d <= 128 and dtype != torch.float8_e4m3fn:
        g = torch.randn_like(out)
        do_o = (g.float() * out.float()).sum(-1)
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
        # P = torch.softmax(qk, -1)
        # dP = P * (dS - do_o.unsqueeze(1))
        # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
        # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
        # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
        # breakpoint()
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    # breakpoint()
    if dtype != torch.float8_e4m3fn:
        assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 3e-5
    else:
        # just test correctness of fp8 kernel w/o further quantization techniques
        assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()

    if d <= 128 and dtype != torch.float8_e4m3fn:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 3e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 3e-5
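

# Varlen test: same comparison against attention_ref, but with variable-length
# sequences packed via cu_seqlens, random query/key padding masks, and (optionally)
# "unused" slots exercised through seqused_q / seqused_k.
# To run just this test, e.g.:
#   pytest -q test_flash_attn.py -k test_flash_attn_varlen_output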
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (113, 203),
        (128, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (512, 256),
        (640, 128),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 1
    # nheads = 1
    batch_size = 9
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)

    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
    )
    v = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
    )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
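
    # When add_unused_qkv is set, intersect the padding mask with a second random
    # mask: the intersection is attended, and the leftover positions become the
    # "unused" mask passed to generate_qkv (exercising seqused_q / seqused_k).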
    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
        if add_unused:
            another_mask = generate_random_padding_mask(max_seq_len, bs, device)
            attn_mask = torch.logical_and(padding_mask, another_mask)
            unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask)
        else:
            attn_mask = padding_mask
            unused_mask = None
        return attn_mask, unused_mask
    query_padding_mask, query_unused_mask = _gen_unused_masks(
        query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
    )
    key_padding_mask, key_unused_mask = _gen_unused_masks(
        key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
    )

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(
        q, k, v, query_padding_mask, key_padding_mask, kvpacked=False,
        query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask,
    )
    # print("cu_seqlens_q: ", cu_seqlens_q)
    # print("cu_seqlens_k: ", cu_seqlens_k)
    # print("q_unpad, shape: ", q_unpad.shape)
    # print("k_unpad, shape: ", k_unpad.shape)
    # print("v_unpad, shape: ", v_unpad.shape)
    out_unpad, sm_lse = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        causal=causal,
        deterministic=deterministic,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
    )
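
    # Re-pad the packed output and zero rows belonging to unused query slots so it
    # is directly comparable to the padded reference output.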
    out = output_pad_fn(out_unpad)
    q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
    out.masked_fill_(q_zero_masking, 0.0)
    dropout_mask = None
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
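    # Backward is only checked for head dims <= 128. Gradients for batches whose key
    # sequence is entirely padded or unused are zeroed in both the kernel output and
    # the references before comparison.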
    if d <= 128:
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dk = dk_pad_fn(dk_unpad)
        # dk_pad_fn also applies to dv: k and v share the same key-side padding and shape.
        dv = dk_pad_fn(dv_unpad)
        k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
        dk.masked_fill_(k_zero_masking, 0.0)
        dv.masked_fill_(k_zero_masking, 0.0)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
        dk_ref.masked_fill_(zero_masking, 0.0)
        dv_ref.masked_fill_(zero_masking, 0.0)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dk_pt.masked_fill_(zero_masking, 0.0)
        dv_pt.masked_fill_(zero_masking, 0.0)
        dq = dq_pad_fn(dq_unpad)
        dq.masked_fill_(q_zero_masking, 0.0)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if d <= 128:
        assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
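

# Not part of the original test suite: a minimal manual-run sketch, assuming a CUDA
# device and a working flash_attn_interface build are available. Direct calls work
# because pytest's parametrize marks only annotate the functions, they do not wrap them.
if __name__ == "__main__":
    test_flash_attn_output(
        seqlen_q=128, seqlen_k=128, d=128, causal=True, deterministic=False,
        mha_type="mha", dtype=torch.float16, descale=1.0,
    )
    test_flash_attn_varlen_output(
        seqlen_q=128, seqlen_k=128, d=128, causal=True, deterministic=False,
        add_unused_qkv=True, mha_type="mha", dtype=torch.float16,
    )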