test_flash_attn.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. import math
  2. import pytest
  3. import torch
  4. import torch.nn.functional as F
  5. from einops import rearrange, repeat
  6. from flash_attn.bert_padding import pad_input, unpad_input
  7. from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
  8. ABS_TOL = 5e-3
  9. REL_TOL = 1e-1
  10. def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
  11. assert mode in ["full", "random", "third"]
  12. if mode == "full":
  13. lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
  14. elif mode == "random":
  15. lengths = torch.randint(
  16. max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
  17. )
  18. elif mode == "third":
  19. lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
  20. padding_mask = (
  21. repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
  22. )
  23. return padding_mask
  24. def generate_qkv(
  25. q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
  26. ):
  27. """
  28. Arguments:
  29. q: (batch_size, seqlen_q, nheads, d)
  30. k: (batch_size, seqlen_k, nheads_k, d)
  31. v: (batch_size, seqlen_k, nheads_k, d)
  32. query_padding_mask: (batch_size, seqlen), bool
  33. key_padding_mask: (batch_size, seqlen), bool
  34. """
  35. assert not (kvpacked and qkvpacked)
  36. batch_size, seqlen_q, nheads, d = q.shape
  37. _, seqlen_k, nheads_k, _ = k.shape
  38. assert k.shape == (batch_size, seqlen_k, nheads_k, d)
  39. assert v.shape == (batch_size, seqlen_k, nheads_k, d)
  40. if query_padding_mask is not None:
  41. q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)
  42. output_pad_fn = lambda output_unpad: pad_input(
  43. output_unpad, indices_q, batch_size, seqlen_q
  44. )
  45. else:
  46. q_unpad = rearrange(q, "b s h d -> (b s) h d")
  47. cu_seqlens_q = torch.arange(
  48. 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
  49. )
  50. max_seqlen_q = seqlen_q
  51. output_pad_fn = lambda output_unpad: rearrange(
  52. output_unpad, "(b s) h d -> b s h d", b=batch_size
  53. )
  54. if key_padding_mask is not None:
  55. k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, *rest = unpad_input(k, key_padding_mask)
  56. v_unpad, _, _, _, *rest = unpad_input(v, key_padding_mask)
  57. else:
  58. k_unpad = rearrange(k, "b s h d -> (b s) h d")
  59. v_unpad = rearrange(v, "b s h d -> (b s) h d")
  60. cu_seqlens_k = torch.arange(
  61. 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
  62. )
  63. max_seqlen_k = seqlen_k
  64. if qkvpacked:
  65. assert (query_padding_mask == key_padding_mask).all()
  66. assert nheads == nheads_k
  67. qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
  68. qkv = torch.stack([q, k, v], dim=2)
  69. if query_padding_mask is not None:
  70. dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
  71. else:
  72. dqkv_pad_fn = lambda dqkv_unpad: rearrange(
  73. dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
  74. )
  75. return (
  76. qkv_unpad.detach().requires_grad_(),
  77. cu_seqlens_q,
  78. max_seqlen_q,
  79. qkv.detach().requires_grad_(),
  80. output_pad_fn,
  81. dqkv_pad_fn,
  82. )
  83. elif kvpacked:
  84. kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
  85. kv = torch.stack([k, v], dim=2)
  86. dq_pad_fn = output_pad_fn
  87. if key_padding_mask is not None:
  88. dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
  89. else:
  90. dkv_pad_fn = lambda dkv_unpad: rearrange(
  91. dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
  92. )
  93. return (
  94. q_unpad.detach().requires_grad_(),
  95. kv_unpad.detach().requires_grad_(),
  96. cu_seqlens_q,
  97. cu_seqlens_k,
  98. max_seqlen_q,
  99. max_seqlen_k,
  100. q.detach().requires_grad_(),
  101. kv.detach().requires_grad_(),
  102. output_pad_fn,
  103. dq_pad_fn,
  104. dkv_pad_fn,
  105. )
  106. else:
  107. dq_pad_fn = output_pad_fn
  108. if key_padding_mask is not None:
  109. dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
  110. else:
  111. dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
  112. return (
  113. q_unpad.detach().requires_grad_(),
  114. k_unpad.detach().requires_grad_(),
  115. v_unpad.detach().requires_grad_(),
  116. cu_seqlens_q,
  117. cu_seqlens_k,
  118. max_seqlen_q,
  119. max_seqlen_k,
  120. q.detach().requires_grad_(),
  121. k.detach().requires_grad_(),
  122. v.detach().requires_grad_(),
  123. output_pad_fn,
  124. dq_pad_fn,
  125. dk_pad_fn,
  126. )
  127. def construct_local_mask(
  128. seqlen_q,
  129. seqlen_k,
  130. window_size=(-1, -1), # -1 means infinite window size
  131. query_padding_mask=None,
  132. key_padding_mask=None,
  133. device=None,
  134. ):
  135. row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
  136. col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
  137. sk = (
  138. seqlen_k
  139. if key_padding_mask is None
  140. else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
  141. )
  142. sq = (
  143. seqlen_q
  144. if query_padding_mask is None
  145. else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
  146. )
  147. if window_size[0] < 0:
  148. return col_idx > row_idx + sk - sq + window_size[1]
  149. else:
  150. sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
  151. return torch.logical_or(
  152. col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
  153. col_idx < row_idx + sk - sq - window_size[0],
  154. )
  155. def print_diffs(out, out_ref):
  156. out_1d = out.flatten()
  157. out_ref_1d = out_ref.flatten()
  158. for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)):
  159. diff = e_o - e_o_ref
  160. abs_diff = abs(diff)
  161. abs_ref = abs(e_o_ref + 1e-5)
  162. relative_diff = abs_diff / abs_ref
  163. if abs_diff > ABS_TOL or relative_diff > REL_TOL:
  164. print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
  165. def attention_ref(
  166. q,
  167. k,
  168. v,
  169. query_padding_mask=None,
  170. key_padding_mask=None,
  171. attn_bias=None,
  172. dropout_p=0.0,
  173. dropout_mask=None,
  174. causal=False,
  175. q_scale=None, k_scale=None, v_scale=None,
  176. window_size=(-1, -1), # -1 means infinite window size
  177. softcap=0.0,
  178. upcast=True,
  179. reorder_ops=False,
  180. intermediate_dtype=None,
  181. ):
  182. """
  183. Arguments:
  184. q: (batch_size, seqlen_q, nheads, head_dim)
  185. k: (batch_size, seqlen_k, nheads, head_dim)
  186. v: (batch_size, seqlen_k, nheads, head_dim)
  187. query_padding_mask: (batch_size, seqlen_q)
  188. key_padding_mask: (batch_size, seqlen_k)
  189. attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
  190. dropout_p: float
  191. dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
  192. causal: whether to apply causal masking
  193. upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
  194. output back to fp16/bf16.
  195. reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
  196. without changing the math. This is to estimate the numerical error from operation
  197. reordering.
  198. Output:
  199. output: (batch_size, seqlen_q, nheads, head_dim)
  200. attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
  201. """
  202. if causal:
  203. window_size = (window_size[0], 0)
  204. dtype_og = q.dtype
  205. if upcast:
  206. q, k, v = q.float(), k.float(), v.float()
  207. if q_scale is not None:
  208. q = (q.float() * q_scale).to(dtype=q.dtype)
  209. if k_scale is not None:
  210. k = (k.float() * k_scale).to(dtype=k.dtype)
  211. if v_scale is not None:
  212. v = (v.float() * v_scale).to(dtype=v.dtype)
  213. seqlen_q, seqlen_k = q.shape[1], k.shape[1]
  214. k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
  215. v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
  216. d = q.shape[-1]
  217. if not reorder_ops:
  218. scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
  219. else:
  220. scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
  221. if softcap > 0:
  222. scores = torch.tanh(scores / softcap) * softcap
  223. if key_padding_mask is not None:
  224. scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
  225. if window_size[0] >= 0 or window_size[1] >= 0:
  226. local_mask = construct_local_mask(
  227. seqlen_q,
  228. seqlen_k,
  229. window_size,
  230. query_padding_mask,
  231. key_padding_mask,
  232. q.device,
  233. )
  234. scores.masked_fill_(local_mask, float("-inf"))
  235. if attn_bias is not None:
  236. scores = scores + attn_bias
  237. attention = torch.softmax(scores, dim=-1).to(v.dtype)
  238. # We want to mask here so that the attention matrix doesn't have any NaNs
  239. # Otherwise we'll get NaN in dV
  240. if query_padding_mask is not None:
  241. attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
  242. # Some rows might be completely masked out so we fill them with zero instead of NaN
  243. if window_size[0] >= 0 or window_size[1] >= 0:
  244. attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
  245. dropout_scaling = 1.0 / (1 - dropout_p)
  246. # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
  247. # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
  248. if dropout_mask is not None:
  249. attention_drop = attention.masked_fill(~dropout_mask, 0.0)
  250. else:
  251. attention_drop = attention
  252. if intermediate_dtype is not None:
  253. attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
  254. output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
  255. if query_padding_mask is not None:
  256. output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
  257. return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
  258. @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
  259. # @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float8_e4m3fn])
  260. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  261. # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  262. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  263. # @pytest.mark.parametrize("mha_type", ["mha"])
  264. # @pytest.mark.parametrize("deterministic", [False, True])
  265. @pytest.mark.parametrize("deterministic", [False])
  266. @pytest.mark.parametrize("softcap", [0.0, 50.0])
  267. # @pytest.mark.parametrize("softcap", [50.0])
  268. @pytest.mark.parametrize("causal,local", [(False, False), (True, False), (False, True)])
  269. # @pytest.mark.parametrize("causal,local", [(False, False)])
  270. # @pytest.mark.parametrize("causal", [False])
  271. @pytest.mark.parametrize("V_colmajor", [False, True])
  272. # @pytest.mark.parametrize("V_colmajor", [False])
  273. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  274. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
  275. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  276. # @pytest.mark.parametrize('d', [56, 80])
  277. # @pytest.mark.parametrize("d", [64, 128, 256])
  278. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
  279. # @pytest.mark.parametrize("d", [64, 96, 128, 192])
  280. @pytest.mark.parametrize("d", [64, 96, 128, 192, 256])
  281. # @pytest.mark.parametrize("d", [128])
  282. @pytest.mark.parametrize(
  283. "seqlen_q,seqlen_k",
  284. [
  285. (64, 128),
  286. (128, 192),
  287. (256, 256),
  288. (113, 203),
  289. (113, 128),
  290. (128, 217),
  291. (113, 211),
  292. (108, 256),
  293. (256, 512),
  294. (384, 256),
  295. (640, 128),
  296. (512, 256),
  297. (1024, 1024),
  298. (1023, 1024),
  299. (1024, 1023),
  300. (2048, 2048),
  301. (8192, 8192),
  302. ],
  303. )
  304. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
  305. def test_flash_attn_output(
  306. seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, mha_type, dtype
  307. ):
  308. if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
  309. pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
  310. if softcap > 0.0 and dtype == torch.float8_e4m3fn:
  311. pytest.skip("Softcap is not supported for float8_e4m3fn")
  312. device = "cuda"
  313. # set seed
  314. torch.random.manual_seed(0)
  315. # batch_size = 40
  316. # nheads = 16
  317. batch_size = 9 if seqlen_k <= 2048 else 2
  318. nheads = 6
  319. # batch_size = 1
  320. # nheads = 1
  321. nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  322. dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
  323. q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  324. if softcap > 0.0:
  325. # Ensure the values of qk are at least within softcap range.
  326. q_ref = (q_ref * softcap / 2).detach().requires_grad_()
  327. k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  328. v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  329. # Put window_size after QKV randn so that window_size changes from test to test
  330. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  331. if dtype == torch.float8_e4m3fn:
  332. q_scale, k_scale, v_scale = [torch.rand(1, device=device, dtype=torch.float32) * 2 for _ in range(3)]
  333. else:
  334. q_scale, k_scale, v_scale = None, None, None
  335. q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
  336. if V_colmajor:
  337. v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_()
  338. out, lse = flash_attn_func(
  339. q,
  340. k,
  341. v,
  342. causal=causal,
  343. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  344. window_size=window_size,
  345. softcap=softcap,
  346. )
  347. out_ref, attn_ref = attention_ref(
  348. q_ref,
  349. k_ref,
  350. v_ref,
  351. None,
  352. None,
  353. causal=causal,
  354. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  355. window_size=window_size,
  356. softcap=softcap
  357. )
  358. out_pt, attn_pt = attention_ref(
  359. q_ref,
  360. k_ref,
  361. v_ref,
  362. None,
  363. None,
  364. causal=causal,
  365. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  366. window_size=window_size,
  367. softcap=softcap,
  368. upcast=False,
  369. reorder_ops=True,
  370. intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
  371. )
  372. # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()
  373. # m = qk.amax(-1, keepdim=True)
  374. # s_tmp = torch.exp((qk - m) / math.sqrt(d))
  375. # exp_sum = s_tmp.sum(-1)
  376. # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
  377. # lse_ref = torch.logsumexp(qk, dim=-1)
  378. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  379. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  380. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  381. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  382. # if not causal:
  383. # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  384. # breakpoint()
  385. if dtype != torch.float8_e4m3fn and not V_colmajor:
  386. g = torch.randn_like(out)
  387. do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
  388. import flashattn_hopper_cuda
  389. dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flashattn_hopper_cuda.bwd(
  390. g,
  391. q,
  392. k,
  393. v,
  394. out,
  395. lse,
  396. None,
  397. None,
  398. None,
  399. d ** (-0.5),
  400. causal,
  401. window_size[0], window_size[1],
  402. softcap,
  403. deterministic,
  404. )
  405. # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
  406. # assert (softmax_d - do_o).abs().max().item() <= 1e-5
  407. # assert dq_accum.abs().max().item() == 0.0
  408. # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
  409. # P = torch.softmax(qk, -1)
  410. # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))
  411. # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
  412. # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
  413. # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
  414. # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  415. dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
  416. dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
  417. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  418. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  419. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  420. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  421. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  422. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  423. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  424. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  425. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  426. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  427. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  428. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  429. # breakpoint()
  430. # Check that FlashAttention's numerical error is at most twice the numerical error
  431. # of a Pytorch implementation.
  432. # multiple = 2 if dtype != torch.float8_e4m3fn else 3
  433. multiple = 2
  434. assert (out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()
  435. if dtype != torch.float8_e4m3fn and not V_colmajor:
  436. multiple = 2 if softcap == 0.0 else 4
  437. assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item()
  438. assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item()
  439. assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item()
  440. @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
  441. # @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float8_e4m3fn])
  442. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  443. # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
  444. @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
  445. # @pytest.mark.parametrize("mha_type", ["mha"])
  446. # @pytest.mark.parametrize("deterministic", [False, True])
  447. @pytest.mark.parametrize("deterministic", [False])
  448. @pytest.mark.parametrize("softcap", [0.0, 50.0])
  449. # @pytest.mark.parametrize("softcap", [50.0])
  450. @pytest.mark.parametrize("causal,local", [(False, False), (True, False), (False, True)])
  451. # @pytest.mark.parametrize("causal,local", [(False, False)])
  452. # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
  453. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
  454. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  455. # @pytest.mark.parametrize('d', [56, 80])
  456. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
  457. # @pytest.mark.parametrize("d", [64, 96, 128])
  458. @pytest.mark.parametrize("d", [64, 96, 128, 192, 256])
  459. @pytest.mark.parametrize(
  460. "seqlen_q,seqlen_k",
  461. [
  462. (64, 128),
  463. (128, 128),
  464. (256, 256),
  465. (113, 203),
  466. (128, 217),
  467. (113, 211),
  468. (108, 256),
  469. (256, 512),
  470. (384, 256),
  471. (640, 128),
  472. (512, 256),
  473. (1024, 1024),
  474. (1023, 1024),
  475. (1024, 1023),
  476. (2048, 2048),
  477. (8192, 8192),
  478. ],
  479. )
  480. # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
  481. def test_flash_attn_varlen_output(
  482. seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, mha_type, dtype
  483. ):
  484. if softcap > 0.0 and dtype == torch.float8_e4m3fn:
  485. pytest.skip("Softcap is not supported for float8_e4m3fn")
  486. device = "cuda"
  487. # set seed
  488. torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal))
  489. # batch_size = 40
  490. # nheads = 16
  491. batch_size = 9 if seqlen_q <= 2048 else 1
  492. nheads = 6
  493. # batch_size = 2
  494. # nheads = 2
  495. nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
  496. dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
  497. q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  498. if softcap > 0.0:
  499. # Ensure the values of qk are at least within softcap range.
  500. q_ref = (q_ref * softcap / 2).detach().requires_grad_()
  501. k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  502. v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
  503. # Put window_size after QKV randn so that window_size changes from test to test
  504. window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
  505. if dtype == torch.float8_e4m3fn:
  506. q_scale, k_scale, v_scale = [torch.rand(1, device=device, dtype=torch.float32) * 2 for _ in range(3)]
  507. else:
  508. q_scale, k_scale, v_scale = None, None, None
  509. q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
  510. query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
  511. key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
  512. (
  513. q_unpad,
  514. k_unpad,
  515. v_unpad,
  516. cu_seqlens_q,
  517. cu_seqlens_k,
  518. max_seqlen_q,
  519. max_seqlen_k,
  520. q,
  521. k,
  522. v,
  523. output_pad_fn,
  524. dq_pad_fn,
  525. dk_pad_fn,
  526. ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
  527. q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]
  528. out_unpad, lse = flash_attn_varlen_func(
  529. q_unpad,
  530. k_unpad,
  531. v_unpad,
  532. cu_seqlens_q,
  533. cu_seqlens_k,
  534. max_seqlen_q,
  535. max_seqlen_k,
  536. causal=causal,
  537. q_scale=q_scale,
  538. k_scale=k_scale, v_scale=v_scale,
  539. window_size=window_size,
  540. softcap=softcap,
  541. )
  542. out = output_pad_fn(out_unpad)
  543. out_ref, attn_ref = attention_ref(
  544. q_ref,
  545. k_ref,
  546. v_ref,
  547. query_padding_mask,
  548. key_padding_mask,
  549. causal=causal,
  550. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  551. window_size=window_size,
  552. softcap=softcap
  553. )
  554. out_pt, attn_pt = attention_ref(
  555. q_ref,
  556. k_ref,
  557. v_ref,
  558. query_padding_mask,
  559. key_padding_mask,
  560. causal=causal,
  561. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  562. window_size=window_size,
  563. softcap=softcap,
  564. upcast=False,
  565. reorder_ops=True,
  566. intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
  567. )
  568. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  569. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  570. print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
  571. print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
  572. # if not causal:
  573. # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
  574. # breakpoint()
  575. if dtype != torch.float8_e4m3fn:
  576. g_unpad = torch.randn_like(out_unpad)
  577. do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
  578. import flashattn_hopper_cuda
  579. dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flashattn_hopper_cuda.bwd_varlen(
  580. g_unpad,
  581. q_unpad,
  582. k_unpad,
  583. v_unpad,
  584. out_unpad,
  585. lse,
  586. None,
  587. None,
  588. None,
  589. cu_seqlens_q,
  590. cu_seqlens_k,
  591. None, None,
  592. max_seqlen_q,
  593. max_seqlen_k,
  594. d ** (-0.5),
  595. causal,
  596. window_size[0], window_size[1],
  597. softcap,
  598. deterministic,
  599. )
  600. dq = dq_pad_fn(dq_unpad)
  601. dk = dk_pad_fn(dk_unpad)
  602. dv = dk_pad_fn(dv_unpad)
  603. # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
  604. # assert (softmax_d - do_o).abs().max().item() <= 1e-5
  605. # assert dq_accum.abs().max().item() == 0.0
  606. g = output_pad_fn(g_unpad)
  607. # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
  608. # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
  609. # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
  610. # P = torch.softmax(qk, -1)
  611. # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
  612. # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
  613. # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
  614. # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
  615. # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
  616. dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
  617. dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
  618. print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
  619. print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
  620. print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
  621. print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
  622. print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
  623. print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
  624. print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
  625. print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
  626. print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
  627. print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
  628. print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
  629. print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
  630. # breakpoint()
  631. # Check that FlashAttention's numerical error is at most twice the numerical error
  632. # of a Pytorch implementation.
  633. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
  634. if dtype != torch.float8_e4m3fn:
  635. multiple = 2 if softcap == 0.0 else 4
  636. assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item()
  637. assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item()
  638. assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item()