# test_flash_attn.py

import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size

MAX_HEADDIM_SM8x = 192


is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask
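# Illustrative example (added comment; not part of the upstream test suite). With mode="full"
# the mask is deterministic:
#   >>> generate_random_padding_mask(4, 2, "cuda", mode="full")
#   tensor([[True, True, True, True],
#           [True, True, True, True]], device='cuda:0')
# True marks a valid token position, False marks padding.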


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
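# Note on the varlen convention (added comment, hedged): cu_seqlens_q / cu_seqlens_k are int32
# prefix sums of the per-batch sequence lengths. For example, lengths [3, 5] give
# cu_seqlens = tensor([0, 3, 8], dtype=torch.int32), and sequence i occupies rows
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed (total_tokens, nheads, d) tensors.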


def construct_causal_mask(
    seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    return col_idx > row_idx + sk - sq
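# Worked example (added comment): the mask is aligned to the bottom-right corner, matching
# FlashAttention's causal convention when seqlen_q != seqlen_k. For seqlen_q=2, seqlen_k=4
# and no padding, True (= masked) positions are:
#   [[False, False, False, True],
#    [False, False, False, False]]
# i.e. query row i may attend to key columns j with j <= i + seqlen_k - seqlen_q.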


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q,
            etc.) without changing the math. This is to estimate the numerical error from
            operation reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        # causal_mask = torch.triu(
        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
        # )
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
        )
        scores.masked_fill_(causal_mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    if causal:  # Some rows are completely masked out so we fill them with zero instead of NaN
        attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
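# Summary of the reference math above (added comment): scores = Q @ K^T / sqrt(d), with padded
# keys and causally-masked positions set to -inf; P = softmax(scores); dropout keeps entries of
# P where dropout_mask is True and rescales by 1 / (1 - dropout_p); output = P @ V, with fully
# masked query rows zeroed so they yield 0 instead of NaN.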


def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        reorder_ops=reorder_ops,
    )


def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3):
    repeats = seqlen // 16 // 2
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
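# Note (added comment): each blockmask entry gates a 16-row x 256-column tile of the score
# matrix; the repeat(...) above expands the (seqlen/16, seqlen/256) block mask to a dense
# (seqlen, seqlen) mask before it is applied elementwise.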


def convert_flash_attn_S_to_softmax(
    S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
):
    """FlashAttention stores the S matrix in a different way.
    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q_rounded)
        key_padding_mask: (batch_size, seqlen_k_rounded)
    """
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    warps_n = 4
    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
    nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
    nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
    mmas_n = (blocksize_n + 16 - 1) // 16
    S_flat = rearrange(
        S,
        "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
        blocksize_m=blocksize_m,
        blocksize_n=blocksize_n,
    )
    S_converted = rearrange(
        S_flat,
        "b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
        mmas_n=mmas_n,
        warps_n=warps_n,
        eight=8,
        c0=2,
        c1=2,
        c2=2,
        four=4,
    )
    if causal:
        # causal_mask = torch.triu(
        #     torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1
        # )
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
        )
        causal_mask = F.pad(
            causal_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted.masked_fill_(causal_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = (
        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    )
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]
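# Note (added comment, hedged): the two rearranges above undo the kernel's per-block, per-warp
# register layout of S so it can be compared elementwise with a standard (seqlen_q, seqlen_k)
# softmax matrix; the trailing F.pad calls (whose pad widths are typically negative) and the
# final slice crop the rounded buffer back to the original, unrounded sequence lengths.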


def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    is_dropout=False,
    causal=False,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), normalized attention probabilities
    """
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        # causal_mask = torch.triu(
        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
        # )
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
        )
        scores.masked_fill_(causal_mask, float("-inf"))
    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)
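# Note (added comment, hedged): each unnormalized block holds values of the form exp(score - m)
# for a per-block running max m (cummax_block above); multiplying by exp(m - lse) rescales every
# block to exp(score - lse), i.e. the properly normalized softmax probability.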


def get_dropout_fraction(
    dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if causal:
        # causal_mask = torch.triu(
        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
        # )
        causal_mask = construct_causal_mask(
            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device
        )
        dropped.masked_fill_(causal_mask, False)
    dropped_total = dropped.sum()
    query_lengths = (
        query_padding_mask.sum(dim=-1)
        if query_padding_mask is not None
        else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
    )
    key_lengths = (
        key_padding_mask.sum(dim=-1)
        if key_padding_mask is not None
        else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
    )
    if not causal:
        numel_per_batch = query_lengths * key_lengths
    else:
        numel_per_batch = torch.where(
            key_lengths <= query_lengths,
            key_lengths * (key_lengths + 1) / 2,
            query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
        )
    return dropped_total / (numel_per_batch.sum() * nheads)
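# Worked example for the causal count (added comment): with bottom-right alignment, if sk <= sq
# the last sk rows see 1, 2, ..., sk keys (earlier rows see none), giving sk * (sk + 1) / 2
# positions; otherwise the sq rows see sk - sq + 1, ..., sk keys, giving
# sq * sk - sq * (sq - 1) / 2. E.g. sq=2, sk=4 -> 3 + 4 = 7 unmasked positions (= 2 * 4 - 1).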


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [97])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv, dropout_p, return_attn_probs=True, causal=causal
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal)
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True
    )
    # v = qkv[:, :, 2].float()
    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
    # if causal:
    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
    #     qk.masked_fill_(causal_mask, float('-inf'))
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)
    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    # do_o = (g.float() * out.float()).sum(-1)
    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )
    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        dropout_p,
        dropout_mask,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q, kv, dropout_p, return_attn_probs=True, causal=causal
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q, k, v, dropout_p, return_attn_probs=True, causal=causal
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal
        )
        dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, None, None, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, None, None, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        if kvpacked:
            (
                dq,
                dkv,
            ) = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')

    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            return_attn_probs=True,
            causal=causal,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            return_attn_probs=True,
            causal=causal,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            query_padding_mask,
            key_padding_mask,
            dropout_p > 0.0,
            causal=causal,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, query_padding_mask, key_padding_mask, causal=causal
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            dropout_p,
            dropout_mask,
            causal=causal,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        if kvpacked:
            (
                dq_unpad,
                dkv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq_unpad,
                dk_unpad,
                dv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        assert abs(dropout_fraction - dropout_p) <= 0.01

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal)
    out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq,
            dk,
            dv,
        ) = torch.autograd.grad(out, (q, k, v), g)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
  1180. @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
  1181. # @pytest.mark.parametrize("dtype", [torch.bfloat16])
  1182. @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
  1183. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
  1184. # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
  1185. # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
  1186. # @pytest.mark.parametrize('d', [56, 80])
  1187. # @pytest.mark.parametrize("d", [128])
  1188. @pytest.mark.parametrize("swap_sq_sk", [False, True])
  1189. # @pytest.mark.parametrize("swap_sq_sk", [True])
  1190. @pytest.mark.parametrize(
  1191. "seqlen_q,seqlen_k",
  1192. [
  1193. (1, 239),
  1194. (3, 799),
  1195. (127, 512),
  1196. (127, 513),
  1197. (113, 203),
  1198. (128, 217),
  1199. (113, 211),
  1200. (108, 256),
  1201. (256, 512),
  1202. (1023, 1024),
  1203. ],
  1204. )
  1205. # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
  1206. def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
  1207. if (
  1208. max(seqlen_q, seqlen_k) >= 2048
  1209. and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
  1210. ):
  1211. pytest.skip() # Reference implementation OOM
  1212. if swap_sq_sk:
  1213. seqlen_q, seqlen_k = seqlen_k, seqlen_q
  1214. device = "cuda"
  1215. causal = True
  1216. # set seed
  1217. torch.random.manual_seed(0)
  1218. batch_size = 16
  1219. nheads = 9
  1220. q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
  1221. k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
  1222. v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
  1223. query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
  1224. key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
  1225. (
  1226. q_unpad,
  1227. k_unpad,
  1228. v_unpad,
  1229. cu_seqlens_q,
  1230. cu_seqlens_k,
  1231. max_seqlen_q,
  1232. max_seqlen_k,
  1233. q,
  1234. k,
  1235. v,
  1236. output_pad_fn,
  1237. dq_pad_fn,
  1238. dk_pad_fn,
  1239. ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
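    # generate_qkv unpads q/k/v according to the padding masks and returns the cumulative
    # sequence lengths plus helpers that re-pad outputs and gradients back to (batch, seqlen, ...).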
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
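        # (v shares k's key_padding_mask, so dk_pad_fn re-pads dv as well.)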
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024),
        (1, 339),
        (64, 800),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (16, 100000),
        (128, 128),
        (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
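    # return_attn_probs=True makes flash_attn_func also return the softmax logsumexp and an
    # attention-probs debug tensor (only the output is checked here).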
    out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq,
            dk,
            dv,
        ) = torch.autograd.grad(out, (q, k, v), g)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 6
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
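    # When seqlen_new_eq_seqlen_q is False, append a random number of new tokens (at most seqlen_q).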
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    cache_seqlens = torch.randint(
        0,
        (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
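    # Current length of each cache entry; when new_kv is set, leave room for the seqlen_new
    # tokens about to be appended.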
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    # k_cache[:, 64:] = -1
    k_cache_ref = k_cache.clone()
    v_cache_ref = v_cache.clone()
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
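    # Emulate the in-place cache update on the reference copies: rows
    # [cache_seqlens, cache_seqlens + seqlen_new) receive the new k / v.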
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
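    # Repeat the KV heads so the reference attention (which expects nheads == nheads_k)
    # also covers the MQA / GQA cases.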
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
    )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
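    # Only the first cache_seqlens keys (plus the newly appended ones, if any) are valid.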
    out_ref, _ = attention_ref(
        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, _ = attention_ref(
        q,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most three times the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
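    # flash_attn_with_kvcache updates k_cache / v_cache in place; the update must match the
    # reference update exactly.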
    if new_kv:
        assert torch.equal(k_cache, k_cache_ref)
        assert torch.equal(v_cache, v_cache_ref)


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
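    # Re-seed before each forward call so dropout draws the same pattern; repeated runs must
    # then match the first run bit-for-bit.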
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq0,
            dk0,
            dv0,
        ) = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()
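        # i.e. tolerate roughly twice the rounding error of a single add/subtract in this dtype.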
    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)
        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
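    # Larger-magnitude inputs help trigger the overflow / NaN-in-dQ bug this test guards against.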
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
        for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item() + 1e-3


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where we were using the wrong strides of dout, which shows up
    when dout is not contiguous.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
    # So g is not contiguous
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = rearrange(out_pt, "b s ... -> s b ...")
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = rearrange(out_ref, "b s ... -> s b ...")
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item()
    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item()
    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0 or varlen.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    Mq = 256
    Mk = 3
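    # Three query segments (76, 34, and 146 tokens), each attending to a single key token.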
    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()