test_flash_attn.py

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn_interface import flash_attn_func

ABS_TOL = 5e-3
REL_TOL = 1e-1


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
):
    """Build a boolean mask that is True at positions that should be masked out,
    given a (left, right) local attention window aligned to the bottom-right corner."""
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )
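

# Illustrative only (not part of the upstream test): a minimal sketch of the causal mask
# produced by construct_local_mask with window_size=(-1, 0) on a tiny CPU example.
# The helper name _demo_causal_mask is hypothetical.
def _demo_causal_mask(seqlen_q=3, seqlen_k=5):
    # With window_size=(-1, 0), position (i, j) is masked iff j > i + (seqlen_k - seqlen_q),
    # i.e. the causal diagonal is aligned to the bottom-right corner.
    mask = construct_local_mask(seqlen_q, seqlen_k, window_size=(-1, 0), device="cpu")
    # For seqlen_q=3, seqlen_k=5 this returns:
    # [[False, False, False,  True,  True],
    #  [False, False, False, False,  True],
    #  [False, False, False, False, False]]
    return mask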


def print_diffs(out, out_ref):
    out_1d = out.flatten()
    out_ref_1d = out_ref.flatten()
    for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)):
        diff = e_o - e_o_ref
        abs_diff = abs(diff)
        abs_ref = abs(e_o_ref + 1e-5)
        relative_diff = abs_diff / abs_ref
        if abs_diff > ABS_TOL or relative_diff > REL_TOL:
            print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            the output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of
            scaling q, etc.) without changing the math. This is to estimate the numerical
            error from operation reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Repeat K/V heads so that MQA/GQA inputs match the number of query heads.
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            (-1, 0),
            None,
            None,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if causal:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
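

# Illustrative only (not part of the upstream test): a minimal sketch of how attention_ref
# is called, assuming small random CPU tensors and a GQA head layout. The helper name
# _demo_attention_ref is hypothetical.
def _demo_attention_ref():
    batch_size, seqlen, nheads, nheads_kv, d = 2, 8, 4, 2, 16
    q = torch.randn(batch_size, seqlen, nheads, d, dtype=torch.float16)
    k = torch.randn(batch_size, seqlen, nheads_kv, d, dtype=torch.float16)
    v = torch.randn(batch_size, seqlen, nheads_kv, d, dtype=torch.float16)
    # K/V have fewer heads than Q (GQA); attention_ref repeats them internally.
    out, attn = attention_ref(q, k, v, None, None, causal=True)
    assert out.shape == (batch_size, seqlen, nheads, d)
    assert attn.shape == (batch_size, nheads, seqlen, seqlen)
    return out, attn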


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, causal, mha_type, dtype
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 9
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # batch_size = 1
    # nheads = 1
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
    out, lse = flash_attn_func(q, k, v, causal=causal)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    # qk = torch.einsum('bshd,bthd->bhst', q, k).float()
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # exp_sum = s_tmp.sum(-1)
    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
    # lse_ref = torch.logsumexp(qk, dim=-1)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # if not causal:
    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
    # breakpoint()
    # if d <= 128:
    #     g = torch.randn_like(out)
    #     do_o = (g.float() * out.float()).sum(-1)
    #     dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    #     dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    #     dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    #     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    #     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    #     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    #     print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    #     print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    #     print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    #     print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    #     print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    #     print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    #     print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    #     print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    #     print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
    # P = torch.softmax(qk, -1)
    # dP = P * (dS - do_o.unsqueeze(1))
    # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
    # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
    # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
    # breakpoint()
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    # if d <= 128:
    #     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    #     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
    #     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
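

# Illustrative only (not part of the upstream test): a minimal sketch of running a single
# configuration directly without pytest, assuming a CUDA device and a built
# flash_attn_interface extension.
if __name__ == "__main__":
    test_flash_attn_output(
        seqlen_q=128, seqlen_k=128, d=64, causal=True, mha_type="gqa", dtype=torch.bfloat16
    )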