flash_attn_interface.py 59 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Sequence, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. import os
  6. # isort: off
  7. # We need to import the CUDA kernels after importing torch
  8. USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
  9. if USE_TRITON_ROCM:
  10. from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
  11. else:
  12. import flash_attn_2_cuda as flash_attn_gpu
  13. # isort: on
  14. def maybe_contiguous(x):
  15. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  16. def _get_block_size_n(device, head_dim, is_dropout, is_causal):
  17. # This should match the block sizes in the CUDA kernel
  18. assert head_dim <= 256
  19. major, minor = torch.cuda.get_device_capability(device)
  20. is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
  21. is_sm80 = major == 8 and minor == 0
  22. is_sm90 = major == 9 and minor == 0
  23. if head_dim <= 32:
  24. return 128
  25. if head_dim <= 64:
  26. return 128 if not is_dropout else 64
  27. elif head_dim <= 96:
  28. return 64
  29. elif head_dim <= 128:
  30. if is_sm8x:
  31. return 64 if (not is_dropout and is_causal) else 32
  32. else:
  33. return 64 if not is_dropout else 32
  34. elif head_dim <= 160:
  35. if is_sm8x:
  36. return 64
  37. else:
  38. return 32
  39. elif head_dim <= 192:
  40. return 64
  41. elif head_dim <= 224:
  42. return 64
  43. elif head_dim <= 256:
  44. return 64
  45. def round_multiple(x, m):
  46. return (x + m - 1) // m * m
  47. # torch.compile() support is only enabled for pytorch >= 2.4
  48. # The reason for this is that we are using the new custom_op and register_fake
  49. # APIs, which support inplace modification of inputs in the function itself
  50. if torch.__version__ >= "2.4.0":
  51. _torch_custom_op_wrapper = torch.library.custom_op
  52. _torch_register_fake_wrapper = torch.library.register_fake
  53. else:
  54. def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
  55. def wrap(func):
  56. return func
  57. if fn is None:
  58. return wrap
  59. return fn
  60. def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
  61. def wrap(func):
  62. return func
  63. if fn is None:
  64. return wrap
  65. return fn
  66. _torch_custom_op_wrapper = noop_custom_op_wrapper
  67. _torch_register_fake_wrapper = noop_register_fake_wrapper
  68. @_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
  69. def _flash_attn_forward(
  70. q: torch.Tensor,
  71. k: torch.Tensor,
  72. v: torch.Tensor,
  73. dropout_p: float,
  74. softmax_scale: float,
  75. causal: bool,
  76. window_size_left: int,
  77. window_size_right: int,
  78. softcap: float,
  79. alibi_slopes: Optional[torch.Tensor],
  80. return_softmax: bool
  81. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  82. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  83. out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
  84. q,
  85. k,
  86. v,
  87. None,
  88. alibi_slopes,
  89. dropout_p,
  90. softmax_scale,
  91. causal,
  92. window_size_left,
  93. window_size_right,
  94. softcap,
  95. return_softmax,
  96. None,
  97. )
  98. return out, softmax_lse, S_dmask, rng_state
  99. @_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
  100. def _flash_attn_forward_fake(
  101. q: torch.Tensor,
  102. k: torch.Tensor,
  103. v: torch.Tensor,
  104. dropout_p: float,
  105. softmax_scale: float,
  106. causal: bool,
  107. window_size_left: int,
  108. window_size_right: int,
  109. softcap: float,
  110. alibi_slopes: Optional[torch.Tensor],
  111. return_softmax: bool
  112. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  113. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  114. batch_size, seqlen_q, num_heads, head_size = q.shape
  115. seqlen_k = k.shape[1]
  116. out = torch.empty_like(q)
  117. softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
  118. p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
  119. if return_softmax:
  120. p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
  121. rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
  122. return out, softmax_lse, p, rng_state
  123. if torch.__version__ >= "2.4.0":
  124. _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
  125. else:
  126. _wrapped_flash_attn_forward = _flash_attn_forward
  127. @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
  128. def _flash_attn_varlen_forward(
  129. q: torch.Tensor,
  130. k: torch.Tensor,
  131. v: torch.Tensor,
  132. cu_seqlens_q: torch.Tensor,
  133. cu_seqlens_k: torch.Tensor,
  134. max_seqlen_q: int,
  135. max_seqlen_k: int,
  136. dropout_p: float,
  137. softmax_scale: float,
  138. causal: bool,
  139. window_size_left: int = -1,
  140. window_size_right: int = -1,
  141. softcap: float = 0.0,
  142. alibi_slopes: Optional[torch.Tensor] = None,
  143. return_softmax: bool = False,
  144. block_table: Optional[torch.Tensor] = None,
  145. leftpad_k: Optional[torch.Tensor] = None,
  146. seqused_k: Optional[torch.Tensor] = None,
  147. zero_tensors: bool = False,
  148. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  149. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  150. out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
  151. q,
  152. k,
  153. v,
  154. None,
  155. cu_seqlens_q,
  156. cu_seqlens_k,
  157. seqused_k,
  158. leftpad_k,
  159. block_table,
  160. alibi_slopes,
  161. max_seqlen_q,
  162. max_seqlen_k,
  163. dropout_p,
  164. softmax_scale,
  165. zero_tensors,
  166. causal,
  167. window_size_left,
  168. window_size_right,
  169. softcap,
  170. return_softmax,
  171. None,
  172. )
  173. # if out.isnan().any() or softmax_lse.isnan().any():
  174. # breakpoint()
  175. return out, softmax_lse, S_dmask, rng_state
  176. @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
  177. def _flash_attn_varlen_forward_fake(
  178. q: torch.Tensor,
  179. k: torch.Tensor,
  180. v: torch.Tensor,
  181. cu_seqlens_q: torch.Tensor,
  182. cu_seqlens_k: torch.Tensor,
  183. max_seqlen_q: int,
  184. max_seqlen_k: int,
  185. dropout_p: float,
  186. softmax_scale: float,
  187. causal: bool,
  188. window_size_left: int = -1,
  189. window_size_right: int = -1,
  190. softcap: float = 0.0,
  191. alibi_slopes: Optional[torch.Tensor] = None,
  192. return_softmax: bool = False,
  193. block_table: Optional[torch.Tensor] = None,
  194. leftpad_k: Optional[torch.Tensor] = None,
  195. seqused_k: Optional[torch.Tensor] = None,
  196. zero_tensors: bool = False,
  197. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  198. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  199. paged_kv = block_table is not None
  200. batch_size = cu_seqlens_q.numel() - 1
  201. total_q, num_heads, _ = q.shape
  202. out = torch.empty_like(q)
  203. softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
  204. p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
  205. seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
  206. seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
  207. if return_softmax:
  208. p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
  209. rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
  210. return out, softmax_lse, p, rng_state
  211. if torch.__version__ >= "2.4.0":
  212. _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
  213. else:
  214. _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
  215. @_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
  216. def _flash_attn_backward(
  217. dout: torch.Tensor,
  218. q: torch.Tensor,
  219. k: torch.Tensor,
  220. v: torch.Tensor,
  221. out: torch.Tensor,
  222. softmax_lse: torch.Tensor,
  223. dq: Optional[torch.Tensor],
  224. dk: Optional[torch.Tensor],
  225. dv: Optional[torch.Tensor],
  226. dropout_p: float,
  227. softmax_scale: float,
  228. causal: bool,
  229. window_size_left: int,
  230. window_size_right: int,
  231. softcap: float,
  232. alibi_slopes: Optional[torch.Tensor],
  233. deterministic: bool,
  234. rng_state: Optional[torch.Tensor] = None,
  235. ) -> torch.Tensor:
  236. # dq, dk, dv are allocated by us so they should already be contiguous
  237. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  238. (
  239. dq,
  240. dk,
  241. dv,
  242. softmax_d,
  243. ) = flash_attn_gpu.bwd(
  244. dout,
  245. q,
  246. k,
  247. v,
  248. out,
  249. softmax_lse,
  250. dq,
  251. dk,
  252. dv,
  253. alibi_slopes,
  254. dropout_p,
  255. softmax_scale,
  256. causal,
  257. window_size_left,
  258. window_size_right,
  259. softcap,
  260. deterministic,
  261. None,
  262. rng_state,
  263. )
  264. return softmax_d
  265. @_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
  266. def _flash_attn_backward_fake(
  267. dout: torch.Tensor,
  268. q: torch.Tensor,
  269. k: torch.Tensor,
  270. v: torch.Tensor,
  271. out: torch.Tensor,
  272. softmax_lse: torch.Tensor,
  273. dq: Optional[torch.Tensor],
  274. dk: Optional[torch.Tensor],
  275. dv: Optional[torch.Tensor],
  276. dropout_p: float,
  277. softmax_scale: float,
  278. causal: bool,
  279. window_size_left: int,
  280. window_size_right: int,
  281. softcap: float,
  282. alibi_slopes: Optional[torch.Tensor],
  283. deterministic: bool,
  284. rng_state: Optional[torch.Tensor] = None,
  285. ) -> torch.Tensor:
  286. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  287. if dq is None:
  288. dq = torch.empty_like(q)
  289. if dk is None:
  290. dk = torch.empty_like(k)
  291. if dv is None:
  292. dv = torch.empty_like(v)
  293. batch_size, seqlen_q, num_heads, _ = q.shape
  294. softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
  295. return softmax_d
  296. if torch.__version__ >= "2.4.0":
  297. _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
  298. else:
  299. _wrapped_flash_attn_backward = _flash_attn_backward
  300. @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
  301. def _flash_attn_varlen_backward(
  302. dout: torch.Tensor,
  303. q: torch.Tensor,
  304. k: torch.Tensor,
  305. v: torch.Tensor,
  306. out: torch.Tensor,
  307. softmax_lse: torch.Tensor,
  308. dq: Optional[torch.Tensor],
  309. dk: Optional[torch.Tensor],
  310. dv: Optional[torch.Tensor],
  311. cu_seqlens_q: torch.Tensor,
  312. cu_seqlens_k: torch.Tensor,
  313. max_seqlen_q: int,
  314. max_seqlen_k: int,
  315. dropout_p: float,
  316. softmax_scale: float,
  317. causal: bool,
  318. window_size_left: int,
  319. window_size_right: int,
  320. softcap: float,
  321. alibi_slopes: Optional[torch.Tensor],
  322. deterministic: bool,
  323. rng_state: Optional[torch.Tensor] = None,
  324. zero_tensors: bool = False,
  325. ) -> torch.Tensor:
  326. # dq, dk, dv are allocated by us so they should already be contiguous
  327. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  328. (
  329. dq,
  330. dk,
  331. dv,
  332. softmax_d,
  333. ) = flash_attn_gpu.varlen_bwd(
  334. dout,
  335. q,
  336. k,
  337. v,
  338. out,
  339. softmax_lse,
  340. dq,
  341. dk,
  342. dv,
  343. cu_seqlens_q,
  344. cu_seqlens_k,
  345. alibi_slopes,
  346. max_seqlen_q,
  347. max_seqlen_k,
  348. dropout_p,
  349. softmax_scale,
  350. zero_tensors,
  351. causal,
  352. window_size_left,
  353. window_size_right,
  354. softcap,
  355. deterministic,
  356. None,
  357. rng_state,
  358. )
  359. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  360. # breakpoint()
  361. return softmax_d
  362. @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
  363. def _flash_attn_varlen_backward_fake(
  364. dout: torch.Tensor,
  365. q: torch.Tensor,
  366. k: torch.Tensor,
  367. v: torch.Tensor,
  368. out: torch.Tensor,
  369. softmax_lse: torch.Tensor,
  370. dq: Optional[torch.Tensor],
  371. dk: Optional[torch.Tensor],
  372. dv: Optional[torch.Tensor],
  373. cu_seqlens_q: torch.Tensor,
  374. cu_seqlens_k: torch.Tensor,
  375. max_seqlen_q: int,
  376. max_seqlen_k: int,
  377. dropout_p: float,
  378. softmax_scale: float,
  379. causal: bool,
  380. window_size_left: int,
  381. window_size_right: int,
  382. softcap: float,
  383. alibi_slopes: Optional[torch.Tensor],
  384. deterministic: bool,
  385. rng_state: Optional[torch.Tensor] = None,
  386. zero_tensors: bool = False,
  387. ) -> torch.Tensor:
  388. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  389. batch_size = cu_seqlens_q.numel() - 1
  390. total_q, num_heads, _ = q.shape
  391. if dq is None:
  392. dq = torch.empty_like(q)
  393. if dk is None:
  394. dk = torch.empty_like(k)
  395. if dv is None:
  396. dv = torch.empty_like(v)
  397. softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
  398. return softmax_d
  399. if torch.__version__ >= "2.4.0":
  400. _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
  401. else:
  402. _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward
  403. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  404. @staticmethod
  405. def forward(
  406. ctx,
  407. qkv,
  408. dropout_p,
  409. softmax_scale,
  410. causal,
  411. window_size,
  412. softcap,
  413. alibi_slopes,
  414. deterministic,
  415. return_softmax,
  416. is_grad_enabled,
  417. ):
  418. is_grad = is_grad_enabled and qkv.requires_grad
  419. if softmax_scale is None:
  420. softmax_scale = qkv.shape[-1] ** (-0.5)
  421. q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
  422. head_size_og = q.size(3)
  423. if head_size_og % 8 != 0:
  424. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  425. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  426. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  427. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  428. q,
  429. k,
  430. v,
  431. dropout_p,
  432. softmax_scale,
  433. causal=causal,
  434. window_size_left=window_size[0],
  435. window_size_right=window_size[1],
  436. softcap=softcap,
  437. alibi_slopes=alibi_slopes,
  438. return_softmax=return_softmax and dropout_p > 0,
  439. )
  440. if is_grad:
  441. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  442. ctx.dropout_p = dropout_p
  443. ctx.softmax_scale = softmax_scale
  444. ctx.causal = causal
  445. ctx.window_size = window_size
  446. ctx.softcap = softcap
  447. ctx.alibi_slopes = alibi_slopes
  448. ctx.deterministic = deterministic
  449. out = out_padded[..., :head_size_og]
  450. return out if not return_softmax else (out, softmax_lse, S_dmask)
  451. @staticmethod
  452. def backward(ctx, dout, *args):
  453. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  454. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  455. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  456. head_size_og = dout.size(3)
  457. dout_padded = dout
  458. if head_size_og % 8 != 0:
  459. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  460. _wrapped_flash_attn_backward(
  461. dout_padded,
  462. q,
  463. k,
  464. v,
  465. out,
  466. softmax_lse,
  467. dqkv[:, :, 0],
  468. dqkv[:, :, 1],
  469. dqkv[:, :, 2],
  470. ctx.dropout_p,
  471. ctx.softmax_scale,
  472. ctx.causal,
  473. ctx.window_size[0],
  474. ctx.window_size[1],
  475. ctx.softcap,
  476. ctx.alibi_slopes,
  477. ctx.deterministic,
  478. rng_state=rng_state,
  479. )
  480. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  481. return dqkv, None, None, None, None, None, None, None, None, None
  482. class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
  483. @staticmethod
  484. def forward(
  485. ctx,
  486. qkv,
  487. cu_seqlens,
  488. max_seqlen,
  489. dropout_p,
  490. softmax_scale,
  491. causal,
  492. window_size,
  493. softcap,
  494. alibi_slopes,
  495. deterministic,
  496. return_softmax,
  497. is_grad_enabled,
  498. ):
  499. is_grad = is_grad_enabled and qkv.requires_grad
  500. if softmax_scale is None:
  501. softmax_scale = qkv.shape[-1] ** (-0.5)
  502. q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
  503. head_size_og = q.size(2)
  504. if head_size_og % 8 != 0:
  505. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  506. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  507. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  508. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  509. q,
  510. k,
  511. v,
  512. cu_seqlens,
  513. cu_seqlens,
  514. max_seqlen,
  515. max_seqlen,
  516. dropout_p,
  517. softmax_scale,
  518. causal=causal,
  519. window_size_left=window_size[0],
  520. window_size_right=window_size[1],
  521. softcap=softcap,
  522. alibi_slopes=alibi_slopes,
  523. return_softmax=return_softmax and dropout_p > 0,
  524. block_table=None,
  525. )
  526. if is_grad:
  527. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
  528. ctx.dropout_p = dropout_p
  529. ctx.max_seqlen = max_seqlen
  530. ctx.softmax_scale = softmax_scale
  531. ctx.causal = causal
  532. ctx.window_size = window_size
  533. ctx.softcap = softcap
  534. ctx.alibi_slopes = alibi_slopes
  535. ctx.deterministic = deterministic
  536. out = out_padded[..., :head_size_og]
  537. return out if not return_softmax else (out, softmax_lse, S_dmask)
  538. @staticmethod
  539. def backward(ctx, dout, *args):
  540. q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
  541. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  542. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  543. head_size_og = dout.size(2)
  544. dout_padded = dout
  545. if head_size_og % 8 != 0:
  546. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  547. _wrapped_flash_attn_varlen_backward(
  548. dout_padded,
  549. q,
  550. k,
  551. v,
  552. out,
  553. softmax_lse,
  554. dqkv[:, 0],
  555. dqkv[:, 1],
  556. dqkv[:, 2],
  557. cu_seqlens,
  558. cu_seqlens,
  559. ctx.max_seqlen,
  560. ctx.max_seqlen,
  561. ctx.dropout_p,
  562. ctx.softmax_scale,
  563. ctx.causal,
  564. ctx.window_size[0],
  565. ctx.window_size[1],
  566. ctx.softcap,
  567. ctx.alibi_slopes,
  568. ctx.deterministic,
  569. rng_state=rng_state,
  570. )
  571. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  572. return dqkv, None, None, None, None, None, None, None, None, None, None, None
  573. class FlashAttnKVPackedFunc(torch.autograd.Function):
  574. @staticmethod
  575. def forward(
  576. ctx,
  577. q,
  578. kv,
  579. dropout_p,
  580. softmax_scale,
  581. causal,
  582. window_size,
  583. softcap,
  584. alibi_slopes,
  585. deterministic,
  586. return_softmax,
  587. is_grad_enabled,
  588. ):
  589. is_grad = is_grad_enabled and any(
  590. x.requires_grad for x in [q, kv]
  591. )
  592. if softmax_scale is None:
  593. softmax_scale = q.shape[-1] ** (-0.5)
  594. k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
  595. head_size_og = q.size(3)
  596. if head_size_og % 8 != 0:
  597. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  598. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  599. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  600. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  601. q,
  602. k,
  603. v,
  604. dropout_p,
  605. softmax_scale,
  606. causal=causal,
  607. window_size_left=window_size[0],
  608. window_size_right=window_size[1],
  609. softcap=softcap,
  610. alibi_slopes=alibi_slopes,
  611. return_softmax=return_softmax and dropout_p > 0,
  612. )
  613. if is_grad:
  614. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  615. ctx.dropout_p = dropout_p
  616. ctx.softmax_scale = softmax_scale
  617. ctx.causal = causal
  618. ctx.window_size = window_size
  619. ctx.softcap = softcap
  620. ctx.alibi_slopes = alibi_slopes
  621. ctx.deterministic = deterministic
  622. out = out_padded[..., :head_size_og]
  623. return out if not return_softmax else (out, softmax_lse, S_dmask)
  624. @staticmethod
  625. def backward(ctx, dout, *args):
  626. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  627. dq = torch.empty_like(q)
  628. kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
  629. dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
  630. head_size_og = dout.size(3)
  631. dout_padded = dout
  632. if head_size_og % 8 != 0:
  633. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  634. _wrapped_flash_attn_backward(
  635. dout_padded,
  636. q,
  637. k,
  638. v,
  639. out,
  640. softmax_lse,
  641. dq,
  642. dkv[:, :, 0],
  643. dkv[:, :, 1],
  644. ctx.dropout_p,
  645. ctx.softmax_scale,
  646. ctx.causal,
  647. ctx.window_size[0],
  648. ctx.window_size[1],
  649. ctx.softcap,
  650. ctx.alibi_slopes,
  651. ctx.deterministic,
  652. rng_state=rng_state,
  653. )
  654. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  655. dkv = dkv[..., : dout.shape[-1]]
  656. return dq, dkv, None, None, None, None, None, None, None, None, None
  657. class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
  658. @staticmethod
  659. def forward(
  660. ctx,
  661. q,
  662. kv,
  663. cu_seqlens_q,
  664. cu_seqlens_k,
  665. max_seqlen_q,
  666. max_seqlen_k,
  667. dropout_p,
  668. softmax_scale,
  669. causal,
  670. window_size,
  671. softcap,
  672. alibi_slopes,
  673. deterministic,
  674. return_softmax,
  675. is_grad_enabled,
  676. ):
  677. is_grad = is_grad_enabled and any(
  678. x.requires_grad for x in [q, kv]
  679. )
  680. if softmax_scale is None:
  681. softmax_scale = q.shape[-1] ** (-0.5)
  682. k, v = kv[:, 0].detach(), kv[:, 1].detach()
  683. head_size_og = q.size(2)
  684. if head_size_og % 8 != 0:
  685. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  686. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  687. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  688. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  689. q,
  690. k,
  691. v,
  692. cu_seqlens_q,
  693. cu_seqlens_k,
  694. max_seqlen_q,
  695. max_seqlen_k,
  696. dropout_p,
  697. softmax_scale,
  698. causal=causal,
  699. window_size_left=window_size[0],
  700. window_size_right=window_size[1],
  701. softcap=softcap,
  702. alibi_slopes=alibi_slopes,
  703. return_softmax=return_softmax and dropout_p > 0,
  704. block_table=None,
  705. )
  706. if is_grad:
  707. ctx.save_for_backward(
  708. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
  709. )
  710. ctx.dropout_p = dropout_p
  711. ctx.max_seqlen_q = max_seqlen_q
  712. ctx.max_seqlen_k = max_seqlen_k
  713. ctx.softmax_scale = softmax_scale
  714. ctx.causal = causal
  715. ctx.window_size = window_size
  716. ctx.softcap = softcap
  717. ctx.alibi_slopes = alibi_slopes
  718. ctx.deterministic = deterministic
  719. out = out_padded[..., :head_size_og]
  720. return out if not return_softmax else (out, softmax_lse, S_dmask)
  721. @staticmethod
  722. def backward(ctx, dout, *args):
  723. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
  724. dq = torch.empty_like(q)
  725. kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
  726. dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
  727. head_size_og = dout.size(2)
  728. dout_padded = dout
  729. if head_size_og % 8 != 0:
  730. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  731. _wrapped_flash_attn_varlen_backward(
  732. dout_padded,
  733. q,
  734. k,
  735. v,
  736. out,
  737. softmax_lse,
  738. dq,
  739. dkv[:, 0],
  740. dkv[:, 1],
  741. cu_seqlens_q,
  742. cu_seqlens_k,
  743. ctx.max_seqlen_q,
  744. ctx.max_seqlen_k,
  745. ctx.dropout_p,
  746. ctx.softmax_scale,
  747. ctx.causal,
  748. ctx.window_size[0],
  749. ctx.window_size[1],
  750. ctx.softcap,
  751. ctx.alibi_slopes,
  752. ctx.deterministic,
  753. rng_state=rng_state,
  754. )
  755. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  756. dkv = dkv[..., : dout.shape[-1]]
  757. return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None
  758. class FlashAttnFunc(torch.autograd.Function):
  759. @staticmethod
  760. def forward(
  761. ctx,
  762. q,
  763. k,
  764. v,
  765. dropout_p,
  766. softmax_scale,
  767. causal,
  768. window_size,
  769. softcap,
  770. alibi_slopes,
  771. deterministic,
  772. return_softmax,
  773. is_grad_enabled,
  774. ):
  775. is_grad = is_grad_enabled and any(
  776. x.requires_grad for x in [q, k, v]
  777. )
  778. if softmax_scale is None:
  779. softmax_scale = q.shape[-1] ** (-0.5)
  780. head_size_og = q.size(3)
  781. if head_size_og % 8 != 0:
  782. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  783. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  784. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  785. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  786. q,
  787. k,
  788. v,
  789. dropout_p,
  790. softmax_scale,
  791. causal=causal,
  792. window_size_left=window_size[0],
  793. window_size_right=window_size[1],
  794. softcap=softcap,
  795. alibi_slopes=alibi_slopes,
  796. return_softmax=return_softmax and dropout_p > 0,
  797. )
  798. if is_grad:
  799. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  800. ctx.dropout_p = dropout_p
  801. ctx.softmax_scale = softmax_scale
  802. ctx.causal = causal
  803. ctx.window_size = window_size
  804. ctx.softcap = softcap
  805. ctx.alibi_slopes = alibi_slopes
  806. ctx.deterministic = deterministic
  807. out = out_padded[..., :head_size_og]
  808. return out if not return_softmax else (out, softmax_lse, S_dmask)
  809. @staticmethod
  810. def backward(ctx, dout, *args):
  811. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  812. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  813. head_size_og = dout.size(3)
  814. dout_padded = dout
  815. if head_size_og % 8 != 0:
  816. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  817. _wrapped_flash_attn_backward(
  818. dout_padded,
  819. q,
  820. k,
  821. v,
  822. out,
  823. softmax_lse,
  824. dq,
  825. dk,
  826. dv,
  827. ctx.dropout_p,
  828. ctx.softmax_scale,
  829. ctx.causal,
  830. ctx.window_size[0],
  831. ctx.window_size[1],
  832. ctx.softcap,
  833. ctx.alibi_slopes,
  834. ctx.deterministic,
  835. rng_state=rng_state,
  836. )
  837. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  838. dk = dk[..., : dout.shape[-1]]
  839. dv = dv[..., : dout.shape[-1]]
  840. return dq, dk, dv, None, None, None, None, None, None, None, None, None
  841. class FlashAttnVarlenFunc(torch.autograd.Function):
  842. @staticmethod
  843. def forward(
  844. ctx,
  845. q,
  846. k,
  847. v,
  848. cu_seqlens_q,
  849. cu_seqlens_k,
  850. max_seqlen_q,
  851. max_seqlen_k,
  852. dropout_p,
  853. softmax_scale,
  854. causal,
  855. window_size,
  856. softcap,
  857. alibi_slopes,
  858. deterministic,
  859. return_softmax,
  860. block_table,
  861. is_grad_enabled,
  862. ):
  863. is_grad = is_grad_enabled and any(
  864. x.requires_grad for x in [q, k, v]
  865. )
  866. if softmax_scale is None:
  867. softmax_scale = q.shape[-1] ** (-0.5)
  868. head_size_og = q.size(2)
  869. if head_size_og % 8 != 0:
  870. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  871. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  872. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  873. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  874. q,
  875. k,
  876. v,
  877. cu_seqlens_q,
  878. cu_seqlens_k,
  879. max_seqlen_q,
  880. max_seqlen_k,
  881. dropout_p,
  882. softmax_scale,
  883. causal=causal,
  884. window_size_left=window_size[0],
  885. window_size_right=window_size[1],
  886. softcap=softcap,
  887. alibi_slopes=alibi_slopes,
  888. return_softmax=return_softmax and dropout_p > 0,
  889. block_table=block_table,
  890. )
  891. if is_grad:
  892. ctx.save_for_backward(
  893. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
  894. )
  895. ctx.dropout_p = dropout_p
  896. ctx.max_seqlen_q = max_seqlen_q
  897. ctx.max_seqlen_k = max_seqlen_k
  898. ctx.softmax_scale = softmax_scale
  899. ctx.causal = causal
  900. ctx.window_size = window_size
  901. ctx.softcap = softcap
  902. ctx.alibi_slopes = alibi_slopes
  903. ctx.deterministic = deterministic
  904. out = out_padded[..., :head_size_og]
  905. return out if not return_softmax else (out, softmax_lse, S_dmask)
  906. @staticmethod
  907. def backward(ctx, dout, *args):
  908. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
  909. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  910. head_size_og = dout.size(2)
  911. dout_padded = dout
  912. if head_size_og % 8 != 0:
  913. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  914. _wrapped_flash_attn_varlen_backward(
  915. dout_padded,
  916. q,
  917. k,
  918. v,
  919. out,
  920. softmax_lse,
  921. dq,
  922. dk,
  923. dv,
  924. cu_seqlens_q,
  925. cu_seqlens_k,
  926. ctx.max_seqlen_q,
  927. ctx.max_seqlen_k,
  928. ctx.dropout_p,
  929. ctx.softmax_scale,
  930. ctx.causal,
  931. ctx.window_size[0],
  932. ctx.window_size[1],
  933. ctx.softcap,
  934. ctx.alibi_slopes,
  935. ctx.deterministic,
  936. rng_state=rng_state,
  937. )
  938. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  939. dk = dk[..., : dout.shape[-1]]
  940. dv = dv[..., : dout.shape[-1]]
  941. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
  942. def flash_attn_qkvpacked_func(
  943. qkv,
  944. dropout_p=0.0,
  945. softmax_scale=None,
  946. causal=False,
  947. window_size=(-1, -1), # -1 means infinite context window
  948. softcap=0.0, # <=0.0 means deactivate
  949. alibi_slopes=None,
  950. deterministic=False,
  951. return_attn_probs=False,
  952. ):
  953. """dropout_p should be set to 0.0 during evaluation
  954. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  955. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  956. of the gradients of Q, K, V.
  957. For multi-query and grouped-query attention (MQA/GQA), please see
  958. flash_attn_kvpacked_func and flash_attn_func.
  959. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  960. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  961. Arguments:
  962. qkv: (batch_size, seqlen, 3, nheads, headdim)
  963. dropout_p: float. Dropout probability.
  964. softmax_scale: float. The scaling of QK^T before applying softmax.
  965. Default to 1 / sqrt(headdim).
  966. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  967. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  968. softcap: float. Anything > 0 activates softcapping attention.
  969. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  970. the attention score of query i and key j.
  971. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  972. which is slightly slower and uses more memory. The forward pass is always deterministic.
  973. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  974. testing only. The returned probabilities are not guaranteed to be correct
  975. (they might not have the right scaling).
  976. Return:
  977. out: (batch_size, seqlen, nheads, headdim).
  978. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  979. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  980. normalization factor).
  981. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  982. The output of softmax (possibly with different scaling). It also encodes the dropout
  983. pattern (negative means that location was dropped, nonnegative means it was kept).
  984. """
  985. return FlashAttnQKVPackedFunc.apply(
  986. qkv,
  987. dropout_p,
  988. softmax_scale,
  989. causal,
  990. window_size,
  991. softcap,
  992. alibi_slopes,
  993. deterministic,
  994. return_attn_probs,
  995. torch.is_grad_enabled(),
  996. )
  997. def flash_attn_kvpacked_func(
  998. q,
  999. kv,
  1000. dropout_p=0.0,
  1001. softmax_scale=None,
  1002. causal=False,
  1003. window_size=(-1, -1), # -1 means infinite context window
  1004. softcap=0.0, # 0.0 means deactivated
  1005. alibi_slopes=None,
  1006. deterministic=False,
  1007. return_attn_probs=False,
  1008. ):
  1009. """dropout_p should be set to 0.0 during evaluation
  1010. If K, V are already stacked into 1 tensor, this function will be faster than
  1011. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  1012. of the gradients of K, V.
  1013. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1014. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1015. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1016. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1017. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1018. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1019. 1 1 1 1 0
  1020. 1 1 1 1 1
  1021. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1022. 0 0
  1023. 0 0
  1024. 0 0
  1025. 1 0
  1026. 1 1
  1027. If the row of the mask is all zero, the output will be zero.
  1028. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1029. will only attend to keys between
  1030. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1031. Arguments:
  1032. q: (batch_size, seqlen, nheads, headdim)
  1033. kv: (batch_size, seqlen, 2, nheads_k, headdim)
  1034. dropout_p: float. Dropout probability.
  1035. softmax_scale: float. The scaling of QK^T before applying softmax.
  1036. Default to 1 / sqrt(headdim).
  1037. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1038. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1039. softcap: float. Anything > 0 activates softcapping attention.
  1040. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1041. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1042. is added to the attention score of query i and key j.
  1043. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1044. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1045. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1046. testing only. The returned probabilities are not guaranteed to be correct
  1047. (they might not have the right scaling).
  1048. Return:
  1049. out: (batch_size, seqlen, nheads, headdim).
  1050. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  1051. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1052. normalization factor).
  1053. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1054. The output of softmax (possibly with different scaling). It also encodes the dropout
  1055. pattern (negative means that location was dropped, nonnegative means it was kept).
  1056. """
  1057. return FlashAttnKVPackedFunc.apply(
  1058. q,
  1059. kv,
  1060. dropout_p,
  1061. softmax_scale,
  1062. causal,
  1063. window_size,
  1064. softcap,
  1065. alibi_slopes,
  1066. deterministic,
  1067. return_attn_probs,
  1068. torch.is_grad_enabled(),
  1069. )
  1070. def flash_attn_func(
  1071. q,
  1072. k,
  1073. v,
  1074. dropout_p=0.0,
  1075. softmax_scale=None,
  1076. causal=False,
  1077. window_size=(-1, -1), # -1 means infinite context window
  1078. softcap=0.0, # 0.0 means deactivated
  1079. alibi_slopes=None,
  1080. deterministic=False,
  1081. return_attn_probs=False,
  1082. ):
  1083. """dropout_p should be set to 0.0 during evaluation
  1084. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1085. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1086. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1087. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1088. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1089. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1090. 1 1 1 1 0
  1091. 1 1 1 1 1
  1092. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1093. 0 0
  1094. 0 0
  1095. 0 0
  1096. 1 0
  1097. 1 1
  1098. If the row of the mask is all zero, the output will be zero.
  1099. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1100. will only attend to keys between
  1101. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1102. Arguments:
  1103. q: (batch_size, seqlen, nheads, headdim)
  1104. k: (batch_size, seqlen, nheads_k, headdim)
  1105. v: (batch_size, seqlen, nheads_k, headdim)
  1106. dropout_p: float. Dropout probability.
  1107. softmax_scale: float. The scaling of QK^T before applying softmax.
  1108. Default to 1 / sqrt(headdim).
  1109. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1110. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1111. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1112. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1113. is added to the attention score of query i and key j.
  1114. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1115. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1116. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1117. testing only. The returned probabilities are not guaranteed to be correct
  1118. (they might not have the right scaling).
  1119. Return:
  1120. out: (batch_size, seqlen, nheads, headdim).
  1121. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  1122. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1123. normalization factor).
  1124. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1125. The output of softmax (possibly with different scaling). It also encodes the dropout
  1126. pattern (negative means that location was dropped, nonnegative means it was kept).
  1127. """
  1128. return FlashAttnFunc.apply(
  1129. q,
  1130. k,
  1131. v,
  1132. dropout_p,
  1133. softmax_scale,
  1134. causal,
  1135. window_size,
  1136. softcap,
  1137. alibi_slopes,
  1138. deterministic,
  1139. return_attn_probs,
  1140. torch.is_grad_enabled(),
  1141. )
  1142. def flash_attn_varlen_qkvpacked_func(
  1143. qkv,
  1144. cu_seqlens,
  1145. max_seqlen,
  1146. dropout_p=0.0,
  1147. softmax_scale=None,
  1148. causal=False,
  1149. window_size=(-1, -1), # -1 means infinite context window
  1150. softcap=0.0, # 0.0 means deactivated
  1151. alibi_slopes=None,
  1152. deterministic=False,
  1153. return_attn_probs=False,
  1154. ):
  1155. """dropout_p should be set to 0.0 during evaluation
  1156. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  1157. calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
  1158. of the gradients of Q, K, V.
  1159. For multi-query and grouped-query attention (MQA/GQA), please see
  1160. flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
  1161. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1162. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  1163. Arguments:
  1164. qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
  1165. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1166. of the sequences in the batch, used to index into qkv.
  1167. max_seqlen: int. Maximum sequence length in the batch.
  1168. dropout_p: float. Dropout probability.
  1169. softmax_scale: float. The scaling of QK^T before applying softmax.
  1170. Default to 1 / sqrt(headdim).
  1171. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1172. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1173. softcap: float. Anything > 0 activates softcapping attention.
  1174. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
  1175. is added to the attention score of query i and key j.
  1176. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1177. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1178. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1179. testing only. The returned probabilities are not guaranteed to be correct
  1180. (they might not have the right scaling).
  1181. Return:
  1182. out: (total, nheads, headdim).
  1183. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1184. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1185. normalization factor).
  1186. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1187. The output of softmax (possibly with different scaling). It also encodes the dropout
  1188. pattern (negative means that location was dropped, nonnegative means it was kept).
  1189. """
  1190. return FlashAttnVarlenQKVPackedFunc.apply(
  1191. qkv,
  1192. cu_seqlens,
  1193. max_seqlen,
  1194. dropout_p,
  1195. softmax_scale,
  1196. causal,
  1197. window_size,
  1198. softcap,
  1199. alibi_slopes,
  1200. deterministic,
  1201. return_attn_probs,
  1202. torch.is_grad_enabled(),
  1203. )
  1204. def flash_attn_varlen_kvpacked_func(
  1205. q,
  1206. kv,
  1207. cu_seqlens_q,
  1208. cu_seqlens_k,
  1209. max_seqlen_q,
  1210. max_seqlen_k,
  1211. dropout_p=0.0,
  1212. softmax_scale=None,
  1213. causal=False,
  1214. window_size=(-1, -1), # -1 means infinite context window
  1215. softcap=0.0, # 0.0 means deactivated
  1216. alibi_slopes=None,
  1217. deterministic=False,
  1218. return_attn_probs=False,
  1219. ):
  1220. """dropout_p should be set to 0.0 during evaluation
  1221. If K, V are already stacked into 1 tensor, this function will be faster than
  1222. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  1223. of the gradients of K, V.
  1224. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1225. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1226. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1227. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1228. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1229. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1230. 1 1 1 1 0
  1231. 1 1 1 1 1
  1232. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1233. 0 0
  1234. 0 0
  1235. 0 0
  1236. 1 0
  1237. 1 1
  1238. If the row of the mask is all zero, the output will be zero.
  1239. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1240. will only attend to keys between
  1241. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1242. Arguments:
  1243. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  1244. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1245. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1246. of the sequences in the batch, used to index into q.
  1247. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1248. of the sequences in the batch, used to index into kv.
  1249. max_seqlen_q: int. Maximum query sequence length in the batch.
  1250. max_seqlen_k: int. Maximum key sequence length in the batch.
  1251. dropout_p: float. Dropout probability.
  1252. softmax_scale: float. The scaling of QK^T before applying softmax.
  1253. Default to 1 / sqrt(headdim).
  1254. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1255. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1256. softcap: float. Anything > 0 activates softcapping attention.
  1257. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1258. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1259. is added to the attention score of query i and key j.
  1260. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1261. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1262. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1263. testing only. The returned probabilities are not guaranteed to be correct
  1264. (they might not have the right scaling).
  1265. Return:
  1266. out: (total, nheads, headdim).
  1267. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1268. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1269. normalization factor).
  1270. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1271. The output of softmax (possibly with different scaling). It also encodes the dropout
  1272. pattern (negative means that location was dropped, nonnegative means it was kept).
  1273. """
  1274. return FlashAttnVarlenKVPackedFunc.apply(
  1275. q,
  1276. kv,
  1277. cu_seqlens_q,
  1278. cu_seqlens_k,
  1279. max_seqlen_q,
  1280. max_seqlen_k,
  1281. dropout_p,
  1282. softmax_scale,
  1283. causal,
  1284. window_size,
  1285. softcap,
  1286. alibi_slopes,
  1287. deterministic,
  1288. return_attn_probs,
  1289. torch.is_grad_enabled(),
  1290. )
  1291. def flash_attn_varlen_func(
  1292. q,
  1293. k,
  1294. v,
  1295. cu_seqlens_q,
  1296. cu_seqlens_k,
  1297. max_seqlen_q,
  1298. max_seqlen_k,
  1299. dropout_p=0.0,
  1300. softmax_scale=None,
  1301. causal=False,
  1302. window_size=(-1, -1), # -1 means infinite context window
  1303. softcap=0.0, # 0.0 means deactivated
  1304. alibi_slopes=None,
  1305. deterministic=False,
  1306. return_attn_probs=False,
  1307. block_table=None,
  1308. ):
  1309. """dropout_p should be set to 0.0 during evaluation
  1310. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  1311. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1312. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1313. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1314. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1315. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1316. 1 1 1 1 0
  1317. 1 1 1 1 1
  1318. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1319. 0 0
  1320. 0 0
  1321. 0 0
  1322. 1 0
  1323. 1 1
  1324. If the row of the mask is all zero, the output will be zero.
  1325. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1326. will only attend to keys between
  1327. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1328. Arguments:
  1329. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  1330. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1331. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1332. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1333. of the sequences in the batch, used to index into q.
  1334. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1335. of the sequences in the batch, used to index into kv.
  1336. max_seqlen_q: int. Maximum query sequence length in the batch.
  1337. max_seqlen_k: int. Maximum key sequence length in the batch.
  1338. dropout_p: float. Dropout probability.
  1339. softmax_scale: float. The scaling of QK^T before applying softmax.
  1340. Default to 1 / sqrt(headdim).
  1341. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1342. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1343. softcap: float. Anything > 0 activates softcapping attention.
  1344. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1345. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1346. is added to the attention score of query i and key j.
  1347. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1348. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1349. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1350. testing only. The returned probabilities are not guaranteed to be correct
  1351. (they might not have the right scaling).
  1352. Return:
  1353. out: (total, nheads, headdim).
  1354. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1355. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1356. normalization factor).
  1357. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1358. The output of softmax (possibly with different scaling). It also encodes the dropout
  1359. pattern (negative means that location was dropped, nonnegative means it was kept).
  1360. """
  1361. return FlashAttnVarlenFunc.apply(
  1362. q,
  1363. k,
  1364. v,
  1365. cu_seqlens_q,
  1366. cu_seqlens_k,
  1367. max_seqlen_q,
  1368. max_seqlen_k,
  1369. dropout_p,
  1370. softmax_scale,
  1371. causal,
  1372. window_size,
  1373. softcap,
  1374. alibi_slopes,
  1375. deterministic,
  1376. return_attn_probs,
  1377. block_table,
  1378. torch.is_grad_enabled(),
  1379. )
  1380. def flash_attn_with_kvcache(
  1381. q,
  1382. k_cache,
  1383. v_cache,
  1384. k=None,
  1385. v=None,
  1386. rotary_cos=None,
  1387. rotary_sin=None,
  1388. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  1389. cache_batch_idx: Optional[torch.Tensor] = None,
  1390. cache_leftpad: Optional[torch.Tensor] = None,
  1391. block_table: Optional[torch.Tensor] = None,
  1392. softmax_scale=None,
  1393. causal=False,
  1394. window_size=(-1, -1), # -1 means infinite context window
  1395. softcap=0.0, # 0.0 means deactivated
  1396. rotary_interleaved=True,
  1397. alibi_slopes=None,
  1398. num_splits=0,
  1399. return_softmax_lse=False,
  1400. ):
  1401. """
  1402. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  1403. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  1404. the previous step, and update them with the new keys/values from the current step, and do
  1405. attention with the updated cache, all in 1 kernel.
  1406. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  1407. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  1408. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  1409. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  1410. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1411. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  1412. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1413. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  1414. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  1415. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  1416. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1417. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1418. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1419. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1420. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1421. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1422. 1 1 1 1 0
  1423. 1 1 1 1 1
  1424. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1425. 0 0
  1426. 0 0
  1427. 0 0
  1428. 1 0
  1429. 1 1
  1430. If the row of the mask is all zero, the output will be zero.
  1431. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1432. will only attend to keys between
  1433. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1434. Note: Does not support backward pass.
  1435. Arguments:
  1436. q: (batch_size, seqlen, nheads, headdim)
  1437. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1438. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1439. page_block_size must be a multiple of 256.
  1440. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1441. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1442. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  1443. k with k_cache, starting at the indices specified by cache_seqlens.
  1444. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  1445. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  1446. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  1447. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  1448. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  1449. KV cache.
  1450. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  1451. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  1452. If the indices are not distinct, and k and v are provided, the values updated in the cache
  1453. might come from any of the duplicate indices.
  1454. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  1455. block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  1456. softmax_scale: float. The scaling of QK^T before applying softmax.
  1457. Default to 1 / sqrt(headdim).
  1458. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1459. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1460. softcap: float. Anything > 0 activates softcapping attention.
  1461. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  1462. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  1463. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  1464. (i.e. GPT-NeoX style).
  1465. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1466. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1467. is added to the attention score of query i and key j.
  1468. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  1469. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  1470. to automatically determine the number of splits.
  1471. Don't change this unless you know what you are doing.
  1472. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  1473. Return:
  1474. out: (batch_size, seqlen, nheads, headdim).
  1475. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  1476. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1477. normalization factor).
  1478. """
  1479. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  1480. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  1481. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  1482. if softmax_scale is None:
  1483. softmax_scale = q.shape[-1] ** (-0.5)
  1484. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  1485. cache_seqlens = torch.full(
  1486. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  1487. )
  1488. cache_seqlens = maybe_contiguous(cache_seqlens)
  1489. cache_batch_idx = maybe_contiguous(cache_batch_idx)
  1490. block_table = maybe_contiguous(block_table)
  1491. out, softmax_lse = flash_attn_gpu.fwd_kvcache(
  1492. q,
  1493. k_cache,
  1494. v_cache,
  1495. k,
  1496. v,
  1497. cache_seqlens,
  1498. rotary_cos,
  1499. rotary_sin,
  1500. cache_batch_idx,
  1501. cache_leftpad,
  1502. block_table,
  1503. alibi_slopes,
  1504. None,
  1505. softmax_scale,
  1506. causal,
  1507. window_size[0],
  1508. window_size[1],
  1509. softcap,
  1510. rotary_interleaved,
  1511. num_splits,
  1512. )
  1513. return (out, softmax_lse) if return_softmax_lse else out