# test_flash_attn_triton_amd.py
import math
import os
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.flash_attn_triton_amd.utils import DEBUG

# Test ROCm Triton backend
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
    random.seed(42)

MAX_HEADDIM_SM8x = 192

is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)


def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
):
    batch, nheads = slopes.shape
    device = slopes.device
    slopes = rearrange(slopes, "b h -> b h 1 1")
    if causal:
        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
    else:
        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
        if key_leftpad is not None:
            key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
            col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
        sk = (
            seqlen_k
            if key_padding_mask is None
            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        sq = (
            seqlen_q
            if query_padding_mask is None
            else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        relative_pos = torch.abs(row_idx + sk - sq - col_idx)
        return -slopes * relative_pos.to(dtype=slopes.dtype)
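

# Illustrative sketch, not invoked by the tests: the shapes the ALiBi bias helper
# produces for toy sizes. The helper name and the toy batch/head/seqlen values below
# are arbitrary choices for illustration, not part of the original suite.
def _example_alibi_bias_shapes():
    slopes = torch.rand(2, 3, dtype=torch.float32) * 0.3  # (batch, nheads)
    bias_causal = attn_bias_from_alibi_slopes(slopes, 4, 4, causal=True)
    bias_full = attn_bias_from_alibi_slopes(slopes, 4, 6, causal=False)
    # Causal bias broadcasts over query rows; the non-causal bias is dense.
    assert bias_causal.shape == (2, 3, 1, 4)
    assert bias_full.shape == (2, 3, 4, 6)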


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask
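

# Illustrative sketch, not invoked by the tests: "full" keeps every position, while
# "random"/"third" draw per-sequence lengths and mask out the tail. The toy sizes and
# the helper name are arbitrary.
def _example_padding_mask():
    mask = generate_random_padding_mask(8, batch_size=2, device="cpu", mode="full")
    assert mask.shape == (2, 8) and bool(mask.all())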


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k
    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
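

# Illustrative sketch, not invoked by the tests: with no padding masks, generate_qkv
# flattens (batch, seqlen) into a single token dimension and builds evenly spaced
# cu_seqlens offsets. The helper name and the toy sizes are arbitrary.
def _example_generate_qkv_unpadded():
    b, sq, sk, h, d = 2, 4, 6, 3, 8
    q = torch.randn(b, sq, h, d)
    k = torch.randn(b, sk, h, d)
    v = torch.randn(b, sk, h, d)
    q_unpad, k_unpad, v_unpad, cu_q, cu_k, max_q, max_k, *_ = generate_qkv(q, k, v)
    assert q_unpad.shape == (b * sq, h, d) and k_unpad.shape == (b * sk, h, d)
    assert cu_q.tolist() == [0, 4, 8] and cu_k.tolist() == [0, 6, 12]
    assert (max_q, max_k) == (sq, sk)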


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )
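

# Illustrative sketch, not invoked by the tests: window_size=(-1, 0) reproduces the
# usual causal mask (True marks positions that are masked out). Toy size chosen
# arbitrarily.
def _example_causal_local_mask():
    mask = construct_local_mask(3, 3, window_size=(-1, 0), device="cpu")
    expected = torch.tensor(
        [[False, True, True],
         [False, False, True],
         [False, False, False]]
    )
    assert torch.equal(mask, expected)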


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
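

# Illustrative sketch, not invoked by the tests: in the simple dense case (no padding,
# no bias, no dropout, no windowing), attention_ref should agree with PyTorch's
# scaled_dot_product_attention up to numerical tolerance. The helper name, toy sizes,
# and tolerance are our own choices for illustration.
def _example_attention_ref_vs_sdpa():
    b, s, h, d = 2, 5, 3, 16
    q = torch.randn(b, s, h, d)
    k = torch.randn(b, s, h, d)
    v = torch.randn(b, s, h, d)
    out_ref, _ = attention_ref(q, k, v, causal=True)
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert torch.allclose(out_ref, out_sdpa, atol=1e-5)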


def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    return attention_ref(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        reorder_ops=reorder_ops,
        key_leftpad=key_leftpad,
    )


def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3):
    repeats = seqlen // 16 // 2
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
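

# Illustrative sketch, not invoked by the tests: the block mask is expanded from
# (seqlen/16, seqlen/256) blocks to a dense (seqlen, seqlen) mask via einops.repeat,
# mirroring the expansion used in attention_blocksparse_ref above. Toy shape only.
def _example_blockmask_expansion():
    blockmask = torch.tensor([[True, False], [False, True]])
    dense = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    assert dense.shape == (32, 512)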


def convert_flash_attn_S_to_softmax(
    S,
    seqlen_q,
    seqlen_k,
    query_padding_mask,
    key_padding_mask,
    head_dim,
    is_dropout,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """FlashAttention stores the S matrix in a different way.
    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q_rounded)
        key_padding_mask: (batch_size, seqlen_k_rounded)
    """
    if causal:
        window_size = (window_size[0], 0)
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    S_converted = S
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            S.device,
        )
        local_mask = F.pad(
            local_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted = S_converted.masked_fill(local_mask, 0.0)
    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = (
        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    )
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]
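

# Illustrative sketch, not invoked by the tests: F.pad with negative amounts trims a
# tensor, which is how the rounded S matrix is cut back toward (seqlen_q, seqlen_k)
# above. The helper name and toy shape are arbitrary.
def _example_negative_pad_trims():
    x = torch.arange(12.0).reshape(1, 1, 3, 4)
    trimmed = F.pad(x, (0, -1))              # drop the last key column
    trimmed = F.pad(trimmed, (0, 0, 0, -1))  # drop the last query row
    assert trimmed.shape == (1, 1, 2, 3)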


def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), the unnormalized attention
            probabilities rescaled by the blockwise softmax normalizer.
    """
    if causal:
        window_size = (window_size[0], 0)
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias.to(dtype=scores.dtype)
    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)
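

# Illustrative sketch, not invoked by the tests: the blockwise logsumexp recombination
# used above matches a single logsumexp over the full key dimension. Toy shapes and
# block size chosen arbitrarily.
def _example_blockwise_logsumexp():
    scores = torch.randn(2, 3, 4, 8)  # (batch, nheads, seqlen_q, seqlen_k)
    blocks = scores.split(4, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in blocks], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    assert torch.allclose(lse, torch.logsumexp(scores, dim=-1), atol=1e-5)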


def get_dropout_fraction(
    dropout_mask,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
    valid = torch.ones_like(dropout_mask)
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
        valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
        valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            dropout_mask.device,
        )
        dropped.masked_fill_(local_mask, False)
        valid.masked_fill_(local_mask, False)
    dropped_total = dropped.sum()
    return dropped_total / valid.sum()
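

# Illustrative sketch, not invoked by the tests: with no padding or windowing, the
# dropout fraction is simply the share of False entries in the keep mask. The helper
# name and the tiny mask below are arbitrary.
def _example_dropout_fraction():
    keep = torch.tensor([[[[True, False], [False, False]]]])  # (1, 1, 2, 2)
    frac = get_dropout_fraction(keep)
    assert torch.isclose(frac, torch.tensor(0.75))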


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [32])
# @pytest.mark.parametrize("seqlen", [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if USE_TRITON_ROCM:
        if dropout_p != 0.0:
            pytest.skip("Dropout not supported in AMD's Triton Backend yet")
        if local:
            pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    # v = qkv[:, :, 2].float()
    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
    # if causal:
    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
    #     qk.masked_fill_(causal_mask, float('-inf'))
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)
    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
    g = torch.randn_like(out)
    # do_o = (g.float() * out.float()).sum(-1)
    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if DEBUG:
        print("dqkv:", dqkv, dqkv.shape)
        print("dqkv_ref:", dqkv_ref, dqkv_ref.shape)
        print("dqkv_pt:", dqkv_pt, dqkv_pt.shape)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_varlen_qkvpacked(
    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
    if USE_TRITON_ROCM:
        if dropout_p != 0.0:
            pytest.skip("Dropout not supported in AMD's Triton Backend yet")
        if local:
            pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )
    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None
    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


# @pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("kvpacked", [False])
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize("d", [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if USE_TRITON_ROCM:
        if dropout_p != 0.0:
            pytest.skip("Dropout not supported on AMD's Triton Backend yet")
        if softcap != 0.0:
            pytest.skip("softcap not supported on AMD's Triton Backend yet")
        if local:
            pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if DEBUG:
        print("out:", out, out.shape)
        print("lse:", lse, lse.shape)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            dq, dkv = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    if DEBUG:
        print("out:", out, out.shape)
        print("out_ref:", out_ref, out_ref.shape)
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if DEBUG:
            print("dv:", dv, dv.shape)
            print("dv_ref:", dv_ref, dv_ref.shape)
            print("dv_pt:", dv_pt, dv_pt.shape)
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
        if DEBUG:
            print("dk:", dk, dk.shape)
            print("dk_ref:", dk_ref, dk_ref.shape)
            print("dk_pt:", dk_pt, dk_pt.shape)
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        if DEBUG:
            print("dq:", dq, dq.shape)
            print("dq_ref:", dq_ref, dq_ref.shape)
            print("dq_pt:", dq_pt, dq_pt.shape)
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
@pytest.mark.parametrize("kvpacked", [False])
# @pytest.mark.parametrize('kvpacked', [False])
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize('mha_type', ["mha"])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [160])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if USE_TRITON_ROCM:
        if dropout_p != 0.0:
            pytest.skip("Dropout not supported in AMD's Triton Backend yet")
        if local:
            pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
        if softcap != 0.0:
            pytest.skip("softcap not supported on AMD's Triton Backend yet")
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
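    # Unpad q/k/v into the packed varlen layout; generate_qkv also returns the
    # cu_seqlens/max_seqlen metadata and the pad-back helpers used below to compare
    # against the dense reference implementation.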
    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if DEBUG:
        print("out_unpad:", out_unpad, out_unpad.shape)
        print("sm_lse:", sm_lse, sm_lse.shape)
    out = output_pad_fn(out_unpad)
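    # With dropout enabled, reconstruct the dropout mask from the returned S_dmask so that
    # the reference implementations below can apply exactly the same mask.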
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None
    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            (
                dq_unpad,
                dkv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq_unpad,
                dk_unpad,
                dv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
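            # k_unpad and v_unpad have the same unpadded shape, so dk_pad_fn pads both dk and dv.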
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if DEBUG:
            print("dv:", dv, dv.shape)
            print("dv_ref:", dv_ref, dv_ref.shape)
            print("dv_pt:", dv_pt, dv_pt.shape)
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
        if DEBUG:
            print("dk:", dk, dk.shape)
            print("dk_ref:", dk_ref, dk_ref.shape)
            print("dk_pt:", dk_pt, dk_pt.shape)
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        if DEBUG:
            print("dq:", dq, dq.shape)
            print("dq_ref:", dq_ref, dq_ref.shape)
            print("dq_pt:", dq_pt, dq_pt.shape)
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()


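# Dense causal test: forward and backward of flash_attn_func with causal=True, checked
# against attention_ref in fp32 (out_ref) and in low precision with reordered ops (out_pt).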
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [32])
# @pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    (
        dq,
        dk,
        dv,
    ) = torch.autograd.grad(out, (q, k, v), g)
    (
        dq_ref,
        dk_ref,
        dv_ref,
    ) = torch.autograd.grad(out_ref, (q, k, v), g)
    (
        dq_pt,
        dk_pt,
        dv_pt,
    ) = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    if DEBUG:
        print("dv:", dv, dv.shape)
        print("dv_ref:", dv_ref, dv_ref.shape)
        print("dv_pt:", dv_pt, dv_pt.shape)
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
    if DEBUG:
        print("dk:", dk, dk.shape)
        print("dk_ref:", dk_ref, dk_ref.shape)
        print("dk_pt:", dk_pt, dk_pt.shape)
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
    if DEBUG:
        print("dq:", dq, dq.shape)
        print("dq_ref:", dq_ref, dq_ref.shape)
        print("dq_pt:", dq_pt, dq_pt.shape)
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5


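# Varlen causal test. With paged_kv_block_size set it also exercises the paged KV-cache
# path (forward only); otherwise both forward and backward are checked.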
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
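    # The backward pass is only checked on the non-paged path; the paged KV-cache path
    # is forward-only here.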
    test_backward = block_table is None
    if test_backward:
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    if test_backward:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


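# KV-cache decode test: flash_attn_with_kvcache is compared against reference attention over
# the (optionally updated) cache, covering MHA/MQA/GQA, optional new-KV append, alibi,
# rotary embedding, paged KV, cache_batch_idx, left-padding, and split-KV (num_splits).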
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
# @pytest.mark.parametrize("rotary_interleaved", [False, True])
@pytest.mark.parametrize("rotary_interleaved", [False])
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    has_leftpad,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
    dtype,
):
    if USE_TRITON_ROCM:
        if paged_kv_block_size is not None:
            pytest.skip("paged attention not supported on AMD's Triton Backend yet")
        if local:
            pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
        if rotary_interleaved or rotary_fraction > 0.0:
            pytest.skip("rotary embedding not supported on AMD's Triton Backend yet")
        if has_leftpad:
            pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet")
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    if has_leftpad and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache,
            v_cache,
            block_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
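    # Per-batch cache lengths; when new KV is appended, the upper bound leaves room for the
    # seqlen_new (or, with rotary + causal/local, seqlen_q) tokens written at the end of the cache.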
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    if has_leftpad:
        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
                                   for i in range(batch_size)])
    else:
        cache_leftpad = None
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    if has_leftpad:
        key_padding_mask = torch.logical_and(
            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
        )
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
            # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (
        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    v_cache_ref = (
        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
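    # Emulate the in-place cache update: scatter the (rotated) new keys and values into the
    # reference caches at positions [cache_seqlens, cache_seqlens + seqlen_new).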
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        key_leftpad=cache_leftpad,
    )
    out_pt, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
        key_leftpad=cache_leftpad,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = (
                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            )
            v_cache_select = (
                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
            )
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
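    # Looser output tolerance with alibi, presumably because the added bias increases the
    # low-precision reference's numerical error.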
    mult = 3 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


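# Helper for the paged (block) KV-cache tests: allocates num_blocks cache blocks, assigns them
# to batches via a random block_table (a random permutation of block indices reshaped to
# (batch_size, blocks_per_batch), e.g. [[3, 0, 5], [2, 4, 1]] for 2 batches and 6 blocks), and
# also returns the equivalent contiguous k_cache/v_cache views for the reference path.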
def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    block_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        # pytorch 1.12 doesn't have indexing with int32
        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks