# flash_attn_interface.py

# Copyright (c) 2023, Tri Dao.

from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import os

# isort: off
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"

if USE_TRITON_ROCM:
    from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
else:
    import flash_attn_2_cuda as flash_attn_gpu

# isort: on


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 128 if not is_dropout else 64
    elif head_dim <= 96:
        return 64
    elif head_dim <= 128:
        if is_sm8x:
            return 64 if (not is_dropout and is_causal) else 32
        else:
            return 64 if not is_dropout else 32
    elif head_dim <= 160:
        if is_sm8x:
            return 64
        else:
            return 32
    elif head_dim <= 192:
        return 64
    elif head_dim <= 224:
        return 64
    elif head_dim <= 256:
        return 64


def round_multiple(x, m):
    return (x + m - 1) // m * m


# torch.compile() support is only enabled for pytorch >= 2.4
# The reason for this is that we are using the new custom_op and register_fake
# APIs, which support inplace modification of inputs in the function itself
if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper


@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        None,
    )
    return out, softmax_lse, S_dmask, rng_state


@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    batch_size, seqlen_q, num_heads, head_size = q.shape
    seqlen_k = k.shape[1]
    out = torch.empty_like(q)
    softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
    if return_softmax:
        p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
else:
    _wrapped_flash_attn_forward = _flash_attn_forward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, softmax_lse, S_dmask, rng_state


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
def _flash_attn_varlen_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    paged_kv = block_table is not None
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape
    out = torch.empty_like(q)
    softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
    if return_softmax:
        p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_gpu.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    if dq is None:
        dq = torch.empty_like(q)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    batch_size, seqlen_q, num_heads, _ = q.shape
    softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
    return softmax_d


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
else:
    _wrapped_flash_attn_backward = _flash_attn_backward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_varlen_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_gpu.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
def _flash_attn_varlen_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape
    if dq is None:
        dq = torch.empty_like(q)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
    return softmax_d


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
else:
    _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        is_grad = torch.is_grad_enabled() and qkv.requires_grad
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        is_grad = torch.is_grad_enabled() and qkv.requires_grad
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
            ctx.dropout_p = dropout_p
            ctx.max_seqlen = max_seqlen
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        is_grad = torch.is_grad_enabled() and any(
            x.requires_grad for x in [q, kv]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        is_grad = torch.is_grad_enabled() and any(
            x.requires_grad for x in [q, kv]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = kv[:, 0].detach(), kv[:, 1].detach()
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
            )
            ctx.dropout_p = dropout_p
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        is_grad = torch.is_grad_enabled() and any(
            x.requires_grad for x in [q, k, v]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        is_grad = torch.is_grad_enabled() and any(
            x.requires_grad for x in [q, k, v]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
            )
            ctx.dropout_p = dropout_p
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
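
# Usage sketch for flash_attn_qkvpacked_func, kept as a comment so importing this module has
# no side effects. The shapes, dtype, and CUDA device below are illustrative assumptions
# chosen to match the docstring above, not values taken from elsewhere in this file.
#
#   qkv = torch.randn(2, 1024, 3, 8, 64, device="cuda", dtype=torch.float16,
#                     requires_grad=True)              # (batch, seqlen, 3, nheads, headdim)
#   out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
#   assert out.shape == (2, 1024, 8, 64)               # (batch, seqlen, nheads, headdim)
#   out.sum().backward()                               # gradient flows back into the packed qkv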


def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
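
# Usage sketch for flash_attn_kvpacked_func with grouped-query attention, kept as a comment.
# Illustrative assumptions only: 8 query heads sharing 2 KV heads, fp16 tensors on a CUDA device.
#
#   q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)       # (batch, seqlen, nheads, headdim)
#   kv = torch.randn(2, 1024, 2, 2, 64, device="cuda", dtype=torch.float16)   # (batch, seqlen, 2, nheads_k, headdim)
#   out = flash_attn_kvpacked_func(q, kv, causal=True)
#   assert out.shape == (2, 1024, 8, 64)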


def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
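
# Usage sketch for flash_attn_func, kept as a comment. Illustrative assumptions only:
# fp16 tensors on a CUDA device, MQA with a single KV head, and a 256-token left-only
# sliding window combined with the causal mask.
#
#   q = torch.randn(4, 2048, 16, 128, device="cuda", dtype=torch.float16)  # (batch, seqlen, nheads, headdim)
#   k = torch.randn(4, 2048, 1, 128, device="cuda", dtype=torch.float16)   # (batch, seqlen, nheads_k, headdim)
#   v = torch.randn(4, 2048, 1, 128, device="cuda", dtype=torch.float16)
#   out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))      # local causal attention
#   assert out.shape == q.shape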


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
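
# Usage sketch for flash_attn_varlen_qkvpacked_func, kept as a comment. Illustrative
# assumptions only: a batch of two sequences of lengths 512 and 1024 packed into one tensor
# without padding, fp16 on a CUDA device. cu_seqlens is the prefix sum of the per-sequence
# lengths, starting at 0.
#
#   seqlens = [512, 1024]
#   total = sum(seqlens)                                                     # 1536 tokens
#   qkv = torch.randn(total, 3, 8, 64, device="cuda", dtype=torch.float16)   # (total, 3, nheads, headdim)
#   cu_seqlens = torch.tensor([0, 512, 1536], device="cuda", dtype=torch.int32)  # (batch_size + 1,)
#   out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen=max(seqlens), causal=True)
#   assert out.shape == (total, 8, 64)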


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
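
# Usage sketch for flash_attn_varlen_kvpacked_func where queries and keys have different
# per-sequence lengths, kept as a comment. Illustrative assumptions only: 2 sequences with
# query lengths [128, 256] and key lengths [512, 1024], fp16 on a CUDA device.
#
#   q = torch.randn(384, 8, 64, device="cuda", dtype=torch.float16)         # (total_q, nheads, headdim)
#   kv = torch.randn(1536, 2, 8, 64, device="cuda", dtype=torch.float16)    # (total_k, 2, nheads_k, headdim)
#   cu_seqlens_q = torch.tensor([0, 128, 384], device="cuda", dtype=torch.int32)
#   cu_seqlens_k = torch.tensor([0, 512, 1536], device="cuda", dtype=torch.int32)
#   out = flash_attn_varlen_kvpacked_func(
#       q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q=256, max_seqlen_k=1024, causal=True
#   )
#   assert out.shape == (384, 8, 64)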
  1276. def flash_attn_varlen_func(
  1277. q,
  1278. k,
  1279. v,
  1280. cu_seqlens_q,
  1281. cu_seqlens_k,
  1282. max_seqlen_q,
  1283. max_seqlen_k,
  1284. dropout_p=0.0,
  1285. softmax_scale=None,
  1286. causal=False,
  1287. window_size=(-1, -1), # -1 means infinite context window
  1288. softcap=0.0, # 0.0 means deactivated
  1289. alibi_slopes=None,
  1290. deterministic=False,
  1291. return_attn_probs=False,
  1292. block_table=None,
  1293. ):
  1294. """dropout_p should be set to 0.0 during evaluation
  1295. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  1296. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1297. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1298. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1299. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1300. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1301. 1 1 1 1 0
  1302. 1 1 1 1 1
  1303. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1304. 0 0
  1305. 0 0
  1306. 0 0
  1307. 1 0
  1308. 1 1
  1309. If the row of the mask is all zero, the output will be zero.
  1310. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1311. will only attend to keys between
  1312. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1313. Arguments:
  1314. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  1315. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1316. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1317. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1318. of the sequences in the batch, used to index into q.
  1319. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1320. of the sequences in the batch, used to index into kv.
  1321. max_seqlen_q: int. Maximum query sequence length in the batch.
  1322. max_seqlen_k: int. Maximum key sequence length in the batch.
  1323. dropout_p: float. Dropout probability.
  1324. softmax_scale: float. The scaling of QK^T before applying softmax.
  1325. Default to 1 / sqrt(headdim).
  1326. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1327. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1328. softcap: float. Anything > 0 activates softcapping attention.
  1329. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1330. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1331. is added to the attention score of query i and key j.
  1332. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1333. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1334. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1335. testing only. The returned probabilities are not guaranteed to be correct
  1336. (they might not have the right scaling).
  1337. Return:
  1338. out: (total, nheads, headdim).
  1339. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1340. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1341. normalization factor).
  1342. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1343. The output of softmax (possibly with different scaling). It also encodes the dropout
  1344. pattern (negative means that location was dropped, nonnegative means it was kept).
  1345. """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
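

# The sketch below is illustrative only and not part of the upstream API: it shows how the
# packed (total, nheads, headdim) layout and the cu_seqlens_* / max_seqlen_* arguments of
# flash_attn_varlen_func fit together for a variable-length self-attention batch. The helper
# name, the example sequence lengths, and all tensor shapes are assumptions chosen for
# demonstration; a CUDA device and fp16 inputs are assumed as well.
def _example_flash_attn_varlen_usage():  # pragma: no cover
    """Illustrative sketch; not part of the library API."""
    import torch  # local import keeps the sketch copy-pasteable on its own

    nheads, nheads_k, headdim = 8, 2, 64  # GQA: 8 query heads share 2 KV heads
    seqlens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
    total = int(seqlens.sum())

    # cu_seqlens is the exclusive prefix sum of the per-sequence lengths
    # ([0, 3, 8, 10] here); it indexes into the packed q/k/v tensors.
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    max_seqlen = int(seqlens.max())

    q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(total, nheads_k, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(total, nheads_k, headdim, dtype=torch.float16, device="cuda")

    # Self-attention: the same cu_seqlens / max_seqlen are used for queries and keys.
    out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
    )
    return out  # (total, nheads, headdim)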


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
  1385. """
  1386. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  1387. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  1388. the previous step, and update them with the new keys/values from the current step, and do
  1389. attention with the updated cache, all in 1 kernel.
  1390. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  1391. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  1392. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  1393. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  1394. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1395. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  1396. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1397. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  1398. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  1399. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  1400. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1401. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1402. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1403. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1404. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1405. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1406. 1 1 1 1 0
  1407. 1 1 1 1 1
  1408. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1409. 0 0
  1410. 0 0
  1411. 0 0
  1412. 1 0
  1413. 1 1
  1414. If the row of the mask is all zero, the output will be zero.
  1415. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1416. will only attend to keys between
  1417. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1418. Note: Does not support backward pass.
  1419. Arguments:
  1420. q: (batch_size, seqlen, nheads, headdim)
  1421. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1422. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1423. page_block_size must be a multiple of 256.
  1424. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1425. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1426. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  1427. k with k_cache, starting at the indices specified by cache_seqlens.
  1428. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  1429. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  1430. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  1431. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  1432. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  1433. KV cache.
  1434. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  1435. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  1436. If the indices are not distinct, and k and v are provided, the values updated in the cache
  1437. might come from any of the duplicate indices.
  1438. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  1439. block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  1440. softmax_scale: float. The scaling of QK^T before applying softmax.
  1441. Default to 1 / sqrt(headdim).
  1442. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1443. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1444. softcap: float. Anything > 0 activates softcapping attention.
  1445. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  1446. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  1447. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  1448. (i.e. GPT-NeoX style).
  1449. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1450. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1451. is added to the attention score of query i and key j.
  1452. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  1453. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  1454. to automatically determine the number of splits.
  1455. Don't change this unless you know what you are doing.
  1456. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  1457. Return:
  1458. out: (batch_size, seqlen, nheads, headdim).
  1459. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  1460. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1461. normalization factor).
  1462. """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        # Default scaling of QK^T: 1 / sqrt(headdim).
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_gpu.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out
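

# The sketch below is illustrative only and not part of the upstream API: it walks through a
# single incremental-decoding step with flash_attn_with_kvcache, appending one new token's K/V
# to a pre-allocated cache in place and attending over the updated cache. The helper name,
# batch size, cache capacity, and head sizes are assumptions chosen for demonstration; a CUDA
# device and fp16 inputs are assumed as well. See tests/test_flash_attn.py for the library's
# own usage examples.
def _example_flash_attn_with_kvcache_usage():  # pragma: no cover
    """Illustrative sketch; not part of the library API."""
    import torch  # local import keeps the sketch copy-pasteable on its own

    batch_size, max_seqlen, nheads, nheads_k, headdim = 2, 1024, 8, 2, 64

    # Pre-allocate the KV cache once, sized for the maximum sequence length.
    k_cache = torch.zeros(
        batch_size, max_seqlen, nheads_k, headdim, dtype=torch.float16, device="cuda"
    )
    v_cache = torch.zeros_like(k_cache)

    # Suppose 10 tokens are already cached for every sequence in the batch.
    cache_seqlens = torch.full((batch_size,), 10, dtype=torch.int32, device="cuda")

    # One new query token plus its new K/V; the kernel writes k/v into
    # k_cache/v_cache in place at position cache_seqlens before attending.
    q = torch.randn(batch_size, 1, nheads, headdim, dtype=torch.float16, device="cuda")
    k_new = torch.randn(batch_size, 1, nheads_k, headdim, dtype=torch.float16, device="cuda")
    v_new = torch.randn_like(k_new)

    out = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        k=k_new,
        v=v_new,
        cache_seqlens=cache_seqlens,
        causal=True,
    )
    return out  # (batch_size, 1, nheads, headdim)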