1
0

flash_attn_interface.py 58 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Sequence, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flash_attn_2_cuda as flash_attn_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _get_block_size_n(device, head_dim, is_dropout, is_causal):
  12. # This should match the block sizes in the CUDA kernel
  13. assert head_dim <= 256
  14. major, minor = torch.cuda.get_device_capability(device)
  15. is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
  16. is_sm80 = major == 8 and minor == 0
  17. is_sm90 = major == 9 and minor == 0
  18. if head_dim <= 32:
  19. return 128
  20. if head_dim <= 64:
  21. return 128 if not is_dropout else 64
  22. elif head_dim <= 96:
  23. return 64
  24. elif head_dim <= 128:
  25. if is_sm8x:
  26. return 64 if (not is_dropout and is_causal) else 32
  27. else:
  28. return 64 if not is_dropout else 32
  29. elif head_dim <= 160:
  30. if is_sm8x:
  31. return 64
  32. else:
  33. return 32
  34. elif head_dim <= 192:
  35. return 64
  36. elif head_dim <= 224:
  37. return 64
  38. elif head_dim <= 256:
  39. return 64
  40. def round_multiple(x, m):
  41. return (x + m - 1) // m * m
  42. # torch.compile() support is only enabled for pytorch >= 2.4
  43. # The reason for this is that we are using the new custom_op and register_fake
  44. # APIs, which support inplace modification of inputs in the function itself
  45. if torch.__version__ >= "2.4.0":
  46. _torch_custom_op_wrapper = torch.library.custom_op
  47. _torch_register_fake_wrapper = torch.library.register_fake
  48. else:
  49. def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
  50. def wrap(func):
  51. return func
  52. if fn is None:
  53. return wrap
  54. return fn
  55. def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
  56. def wrap(func):
  57. return func
  58. if fn is None:
  59. return wrap
  60. return fn
  61. _torch_custom_op_wrapper = noop_custom_op_wrapper
  62. _torch_register_fake_wrapper = noop_register_fake_wrapper
  63. @_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
  64. def _flash_attn_forward(
  65. q: torch.Tensor,
  66. k: torch.Tensor,
  67. v: torch.Tensor,
  68. dropout_p: float,
  69. softmax_scale: float,
  70. causal: bool,
  71. window_size_left: int,
  72. window_size_right: int,
  73. softcap: float,
  74. alibi_slopes: Optional[torch.Tensor],
  75. return_softmax: bool
  76. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  77. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  78. out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
  79. q,
  80. k,
  81. v,
  82. None,
  83. alibi_slopes,
  84. dropout_p,
  85. softmax_scale,
  86. causal,
  87. window_size_left,
  88. window_size_right,
  89. softcap,
  90. return_softmax,
  91. None,
  92. )
  93. return out, softmax_lse, S_dmask, rng_state
  94. @_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
  95. def _flash_attn_forward_fake(
  96. q: torch.Tensor,
  97. k: torch.Tensor,
  98. v: torch.Tensor,
  99. dropout_p: float,
  100. softmax_scale: float,
  101. causal: bool,
  102. window_size_left: int,
  103. window_size_right: int,
  104. softcap: float,
  105. alibi_slopes: Optional[torch.Tensor],
  106. return_softmax: bool
  107. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  108. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  109. batch_size, seqlen_q, num_heads, head_size = q.shape
  110. seqlen_k = k.shape[1]
  111. out = torch.empty_like(q)
  112. softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
  113. p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
  114. if return_softmax:
  115. p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
  116. rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
  117. return out, softmax_lse, p, rng_state
  118. if torch.__version__ >= "2.4.0":
  119. _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
  120. else:
  121. _wrapped_flash_attn_forward = _flash_attn_forward
  122. @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
  123. def _flash_attn_varlen_forward(
  124. q: torch.Tensor,
  125. k: torch.Tensor,
  126. v: torch.Tensor,
  127. cu_seqlens_q: torch.Tensor,
  128. cu_seqlens_k: torch.Tensor,
  129. max_seqlen_q: int,
  130. max_seqlen_k: int,
  131. dropout_p: float,
  132. softmax_scale: float,
  133. causal: bool,
  134. window_size_left: int = -1,
  135. window_size_right: int = -1,
  136. softcap: float = 0.0,
  137. alibi_slopes: Optional[torch.Tensor] = None,
  138. return_softmax: bool = False,
  139. block_table: Optional[torch.Tensor] = None,
  140. leftpad_k: Optional[torch.Tensor] = None,
  141. seqused_k: Optional[torch.Tensor] = None,
  142. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  143. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  144. out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
  145. q,
  146. k,
  147. v,
  148. None,
  149. cu_seqlens_q,
  150. cu_seqlens_k,
  151. seqused_k,
  152. leftpad_k,
  153. block_table,
  154. alibi_slopes,
  155. max_seqlen_q,
  156. max_seqlen_k,
  157. dropout_p,
  158. softmax_scale,
  159. False,
  160. causal,
  161. window_size_left,
  162. window_size_right,
  163. softcap,
  164. return_softmax,
  165. None,
  166. )
  167. # if out.isnan().any() or softmax_lse.isnan().any():
  168. # breakpoint()
  169. return out, softmax_lse, S_dmask, rng_state
  170. @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
  171. def _flash_attn_varlen_forward_fake(
  172. q: torch.Tensor,
  173. k: torch.Tensor,
  174. v: torch.Tensor,
  175. cu_seqlens_q: torch.Tensor,
  176. cu_seqlens_k: torch.Tensor,
  177. max_seqlen_q: int,
  178. max_seqlen_k: int,
  179. dropout_p: float,
  180. softmax_scale: float,
  181. causal: bool,
  182. window_size_left: int = -1,
  183. window_size_right: int = -1,
  184. softcap: float = 0.0,
  185. alibi_slopes: Optional[torch.Tensor] = None,
  186. return_softmax: bool = False,
  187. block_table: Optional[torch.Tensor] = None,
  188. leftpad_k: Optional[torch.Tensor] = None,
  189. seqused_k: Optional[torch.Tensor] = None,
  190. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  191. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  192. paged_kv = block_table is not None
  193. batch_size = cu_seqlens_q.numel() - 1
  194. total_q, num_heads, _ = q.shape
  195. out = torch.empty_like(q)
  196. softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
  197. p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
  198. seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
  199. seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
  200. if return_softmax:
  201. p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
  202. rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
  203. return out, softmax_lse, p, rng_state
  204. if torch.__version__ >= "2.4.0":
  205. _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
  206. else:
  207. _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
  208. @_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
  209. def _flash_attn_backward(
  210. dout: torch.Tensor,
  211. q: torch.Tensor,
  212. k: torch.Tensor,
  213. v: torch.Tensor,
  214. out: torch.Tensor,
  215. softmax_lse: torch.Tensor,
  216. dq: Optional[torch.Tensor],
  217. dk: Optional[torch.Tensor],
  218. dv: Optional[torch.Tensor],
  219. dropout_p: float,
  220. softmax_scale: float,
  221. causal: bool,
  222. window_size_left: int,
  223. window_size_right: int,
  224. softcap: float,
  225. alibi_slopes: Optional[torch.Tensor],
  226. deterministic: bool,
  227. rng_state: Optional[torch.Tensor] = None,
  228. ) -> torch.Tensor:
  229. # dq, dk, dv are allocated by us so they should already be contiguous
  230. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  231. (
  232. dq,
  233. dk,
  234. dv,
  235. softmax_d,
  236. ) = flash_attn_cuda.bwd(
  237. dout,
  238. q,
  239. k,
  240. v,
  241. out,
  242. softmax_lse,
  243. dq,
  244. dk,
  245. dv,
  246. alibi_slopes,
  247. dropout_p,
  248. softmax_scale,
  249. causal,
  250. window_size_left,
  251. window_size_right,
  252. softcap,
  253. deterministic,
  254. None,
  255. rng_state,
  256. )
  257. return softmax_d
  258. @_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
  259. def _flash_attn_backward_fake(
  260. dout: torch.Tensor,
  261. q: torch.Tensor,
  262. k: torch.Tensor,
  263. v: torch.Tensor,
  264. out: torch.Tensor,
  265. softmax_lse: torch.Tensor,
  266. dq: Optional[torch.Tensor],
  267. dk: Optional[torch.Tensor],
  268. dv: Optional[torch.Tensor],
  269. dropout_p: float,
  270. softmax_scale: float,
  271. causal: bool,
  272. window_size_left: int,
  273. window_size_right: int,
  274. softcap: float,
  275. alibi_slopes: Optional[torch.Tensor],
  276. deterministic: bool,
  277. rng_state: Optional[torch.Tensor] = None,
  278. ) -> torch.Tensor:
  279. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  280. if dq is None:
  281. dq = torch.empty_like(q)
  282. if dk is None:
  283. dk = torch.empty_like(k)
  284. if dv is None:
  285. dv = torch.empty_like(v)
  286. batch_size, seqlen_q, num_heads, _ = q.shape
  287. softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
  288. return softmax_d
  289. if torch.__version__ >= "2.4.0":
  290. _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
  291. else:
  292. _wrapped_flash_attn_backward = _flash_attn_backward
  293. @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
  294. def _flash_attn_varlen_backward(
  295. dout: torch.Tensor,
  296. q: torch.Tensor,
  297. k: torch.Tensor,
  298. v: torch.Tensor,
  299. out: torch.Tensor,
  300. softmax_lse: torch.Tensor,
  301. dq: Optional[torch.Tensor],
  302. dk: Optional[torch.Tensor],
  303. dv: Optional[torch.Tensor],
  304. cu_seqlens_q: torch.Tensor,
  305. cu_seqlens_k: torch.Tensor,
  306. max_seqlen_q: int,
  307. max_seqlen_k: int,
  308. dropout_p: float,
  309. softmax_scale: float,
  310. causal: bool,
  311. window_size_left: int,
  312. window_size_right: int,
  313. softcap: float,
  314. alibi_slopes: Optional[torch.Tensor],
  315. deterministic: bool,
  316. rng_state: Optional[torch.Tensor] = None,
  317. ) -> torch.Tensor:
  318. # dq, dk, dv are allocated by us so they should already be contiguous
  319. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  320. (
  321. dq,
  322. dk,
  323. dv,
  324. softmax_d,
  325. ) = flash_attn_cuda.varlen_bwd(
  326. dout,
  327. q,
  328. k,
  329. v,
  330. out,
  331. softmax_lse,
  332. dq,
  333. dk,
  334. dv,
  335. cu_seqlens_q,
  336. cu_seqlens_k,
  337. alibi_slopes,
  338. max_seqlen_q,
  339. max_seqlen_k,
  340. dropout_p,
  341. softmax_scale,
  342. False,
  343. causal,
  344. window_size_left,
  345. window_size_right,
  346. softcap,
  347. deterministic,
  348. None,
  349. rng_state,
  350. )
  351. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  352. # breakpoint()
  353. return softmax_d
  354. @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
  355. def _flash_attn_varlen_backward_fake(
  356. dout: torch.Tensor,
  357. q: torch.Tensor,
  358. k: torch.Tensor,
  359. v: torch.Tensor,
  360. out: torch.Tensor,
  361. softmax_lse: torch.Tensor,
  362. dq: Optional[torch.Tensor],
  363. dk: Optional[torch.Tensor],
  364. dv: Optional[torch.Tensor],
  365. cu_seqlens_q: torch.Tensor,
  366. cu_seqlens_k: torch.Tensor,
  367. max_seqlen_q: int,
  368. max_seqlen_k: int,
  369. dropout_p: float,
  370. softmax_scale: float,
  371. causal: bool,
  372. window_size_left: int,
  373. window_size_right: int,
  374. softcap: float,
  375. alibi_slopes: Optional[torch.Tensor],
  376. deterministic: bool,
  377. rng_state: Optional[torch.Tensor] = None,
  378. ) -> torch.Tensor:
  379. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  380. batch_size = cu_seqlens_q.numel() - 1
  381. total_q, num_heads, _ = q.shape
  382. if dq is None:
  383. dq = torch.empty_like(q)
  384. if dk is None:
  385. dk = torch.empty_like(k)
  386. if dv is None:
  387. dv = torch.empty_like(v)
  388. softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
  389. return softmax_d
  390. if torch.__version__ >= "2.4.0":
  391. _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
  392. else:
  393. _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward
  394. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  395. @staticmethod
  396. def forward(
  397. ctx,
  398. qkv,
  399. dropout_p,
  400. softmax_scale,
  401. causal,
  402. window_size,
  403. softcap,
  404. alibi_slopes,
  405. deterministic,
  406. return_softmax,
  407. ):
  408. if softmax_scale is None:
  409. softmax_scale = qkv.shape[-1] ** (-0.5)
  410. q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
  411. head_size_og = q.size(3)
  412. if head_size_og % 8 != 0:
  413. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  414. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  415. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  416. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  417. q,
  418. k,
  419. v,
  420. dropout_p,
  421. softmax_scale,
  422. causal=causal,
  423. window_size_left=window_size[0],
  424. window_size_right=window_size[1],
  425. softcap=softcap,
  426. alibi_slopes=alibi_slopes,
  427. return_softmax=return_softmax and dropout_p > 0,
  428. )
  429. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  430. ctx.dropout_p = dropout_p
  431. ctx.softmax_scale = softmax_scale
  432. ctx.causal = causal
  433. ctx.window_size = window_size
  434. ctx.softcap = softcap
  435. ctx.alibi_slopes = alibi_slopes
  436. ctx.deterministic = deterministic
  437. out = out_padded[..., :head_size_og]
  438. return out if not return_softmax else (out, softmax_lse, S_dmask)
  439. @staticmethod
  440. def backward(ctx, dout, *args):
  441. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  442. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  443. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  444. head_size_og = dout.size(3)
  445. dout_padded = dout
  446. if head_size_og % 8 != 0:
  447. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  448. _wrapped_flash_attn_backward(
  449. dout_padded,
  450. q,
  451. k,
  452. v,
  453. out,
  454. softmax_lse,
  455. dqkv[:, :, 0],
  456. dqkv[:, :, 1],
  457. dqkv[:, :, 2],
  458. ctx.dropout_p,
  459. ctx.softmax_scale,
  460. ctx.causal,
  461. ctx.window_size[0],
  462. ctx.window_size[1],
  463. ctx.softcap,
  464. ctx.alibi_slopes,
  465. ctx.deterministic,
  466. rng_state=rng_state,
  467. )
  468. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  469. return dqkv, None, None, None, None, None, None, None, None
  470. class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
  471. @staticmethod
  472. def forward(
  473. ctx,
  474. qkv,
  475. cu_seqlens,
  476. max_seqlen,
  477. dropout_p,
  478. softmax_scale,
  479. causal,
  480. window_size,
  481. softcap,
  482. alibi_slopes,
  483. deterministic,
  484. return_softmax,
  485. ):
  486. if softmax_scale is None:
  487. softmax_scale = qkv.shape[-1] ** (-0.5)
  488. q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
  489. head_size_og = q.size(2)
  490. if head_size_og % 8 != 0:
  491. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  492. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  493. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  494. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  495. q,
  496. k,
  497. v,
  498. cu_seqlens,
  499. cu_seqlens,
  500. max_seqlen,
  501. max_seqlen,
  502. dropout_p,
  503. softmax_scale,
  504. causal=causal,
  505. window_size_left=window_size[0],
  506. window_size_right=window_size[1],
  507. softcap=softcap,
  508. alibi_slopes=alibi_slopes,
  509. return_softmax=return_softmax and dropout_p > 0,
  510. block_table=None,
  511. )
  512. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
  513. ctx.dropout_p = dropout_p
  514. ctx.max_seqlen = max_seqlen
  515. ctx.softmax_scale = softmax_scale
  516. ctx.causal = causal
  517. ctx.window_size = window_size
  518. ctx.softcap = softcap
  519. ctx.alibi_slopes = alibi_slopes
  520. ctx.deterministic = deterministic
  521. out = out_padded[..., :head_size_og]
  522. return out if not return_softmax else (out, softmax_lse, S_dmask)
  523. @staticmethod
  524. def backward(ctx, dout, *args):
  525. q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
  526. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  527. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  528. head_size_og = dout.size(2)
  529. dout_padded = dout
  530. if head_size_og % 8 != 0:
  531. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  532. _wrapped_flash_attn_varlen_backward(
  533. dout_padded,
  534. q,
  535. k,
  536. v,
  537. out,
  538. softmax_lse,
  539. dqkv[:, 0],
  540. dqkv[:, 1],
  541. dqkv[:, 2],
  542. cu_seqlens,
  543. cu_seqlens,
  544. ctx.max_seqlen,
  545. ctx.max_seqlen,
  546. ctx.dropout_p,
  547. ctx.softmax_scale,
  548. ctx.causal,
  549. ctx.window_size[0],
  550. ctx.window_size[1],
  551. ctx.softcap,
  552. ctx.alibi_slopes,
  553. ctx.deterministic,
  554. rng_state=rng_state,
  555. )
  556. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  557. return dqkv, None, None, None, None, None, None, None, None, None, None
  558. class FlashAttnKVPackedFunc(torch.autograd.Function):
  559. @staticmethod
  560. def forward(
  561. ctx,
  562. q,
  563. kv,
  564. dropout_p,
  565. softmax_scale,
  566. causal,
  567. window_size,
  568. softcap,
  569. alibi_slopes,
  570. deterministic,
  571. return_softmax,
  572. ):
  573. if softmax_scale is None:
  574. softmax_scale = q.shape[-1] ** (-0.5)
  575. k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
  576. head_size_og = q.size(3)
  577. if head_size_og % 8 != 0:
  578. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  579. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  580. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  581. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  582. q,
  583. k,
  584. v,
  585. dropout_p,
  586. softmax_scale,
  587. causal=causal,
  588. window_size_left=window_size[0],
  589. window_size_right=window_size[1],
  590. softcap=softcap,
  591. alibi_slopes=alibi_slopes,
  592. return_softmax=return_softmax and dropout_p > 0,
  593. )
  594. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  595. ctx.dropout_p = dropout_p
  596. ctx.softmax_scale = softmax_scale
  597. ctx.causal = causal
  598. ctx.window_size = window_size
  599. ctx.softcap = softcap
  600. ctx.alibi_slopes = alibi_slopes
  601. ctx.deterministic = deterministic
  602. out = out_padded[..., :head_size_og]
  603. return out if not return_softmax else (out, softmax_lse, S_dmask)
  604. @staticmethod
  605. def backward(ctx, dout, *args):
  606. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  607. dq = torch.empty_like(q)
  608. kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
  609. dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
  610. head_size_og = dout.size(3)
  611. dout_padded = dout
  612. if head_size_og % 8 != 0:
  613. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  614. _wrapped_flash_attn_backward(
  615. dout_padded,
  616. q,
  617. k,
  618. v,
  619. out,
  620. softmax_lse,
  621. dq,
  622. dkv[:, :, 0],
  623. dkv[:, :, 1],
  624. ctx.dropout_p,
  625. ctx.softmax_scale,
  626. ctx.causal,
  627. ctx.window_size[0],
  628. ctx.window_size[1],
  629. ctx.softcap,
  630. ctx.alibi_slopes,
  631. ctx.deterministic,
  632. rng_state=rng_state,
  633. )
  634. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  635. dkv = dkv[..., : dout.shape[-1]]
  636. return dq, dkv, None, None, None, None, None, None, None, None
  637. class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
  638. @staticmethod
  639. def forward(
  640. ctx,
  641. q,
  642. kv,
  643. cu_seqlens_q,
  644. cu_seqlens_k,
  645. max_seqlen_q,
  646. max_seqlen_k,
  647. dropout_p,
  648. softmax_scale,
  649. causal,
  650. window_size,
  651. softcap,
  652. alibi_slopes,
  653. deterministic,
  654. return_softmax,
  655. ):
  656. if softmax_scale is None:
  657. softmax_scale = q.shape[-1] ** (-0.5)
  658. k, v = kv[:, 0].detach(), kv[:, 1].detach()
  659. head_size_og = q.size(2)
  660. if head_size_og % 8 != 0:
  661. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  662. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  663. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  664. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  665. q,
  666. k,
  667. v,
  668. cu_seqlens_q,
  669. cu_seqlens_k,
  670. max_seqlen_q,
  671. max_seqlen_k,
  672. dropout_p,
  673. softmax_scale,
  674. causal=causal,
  675. window_size_left=window_size[0],
  676. window_size_right=window_size[1],
  677. softcap=softcap,
  678. alibi_slopes=alibi_slopes,
  679. return_softmax=return_softmax and dropout_p > 0,
  680. block_table=None,
  681. )
  682. ctx.save_for_backward(
  683. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
  684. )
  685. ctx.dropout_p = dropout_p
  686. ctx.max_seqlen_q = max_seqlen_q
  687. ctx.max_seqlen_k = max_seqlen_k
  688. ctx.softmax_scale = softmax_scale
  689. ctx.causal = causal
  690. ctx.window_size = window_size
  691. ctx.softcap = softcap
  692. ctx.alibi_slopes = alibi_slopes
  693. ctx.deterministic = deterministic
  694. out = out_padded[..., :head_size_og]
  695. return out if not return_softmax else (out, softmax_lse, S_dmask)
  696. @staticmethod
  697. def backward(ctx, dout, *args):
  698. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
  699. dq = torch.empty_like(q)
  700. kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
  701. dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
  702. head_size_og = dout.size(2)
  703. dout_padded = dout
  704. if head_size_og % 8 != 0:
  705. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  706. _wrapped_flash_attn_varlen_backward(
  707. dout_padded,
  708. q,
  709. k,
  710. v,
  711. out,
  712. softmax_lse,
  713. dq,
  714. dkv[:, 0],
  715. dkv[:, 1],
  716. cu_seqlens_q,
  717. cu_seqlens_k,
  718. ctx.max_seqlen_q,
  719. ctx.max_seqlen_k,
  720. ctx.dropout_p,
  721. ctx.softmax_scale,
  722. ctx.causal,
  723. ctx.window_size[0],
  724. ctx.window_size[1],
  725. ctx.softcap,
  726. ctx.alibi_slopes,
  727. ctx.deterministic,
  728. rng_state=rng_state,
  729. )
  730. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  731. dkv = dkv[..., : dout.shape[-1]]
  732. return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
  733. class FlashAttnFunc(torch.autograd.Function):
  734. @staticmethod
  735. def forward(
  736. ctx,
  737. q,
  738. k,
  739. v,
  740. dropout_p,
  741. softmax_scale,
  742. causal,
  743. window_size,
  744. softcap,
  745. alibi_slopes,
  746. deterministic,
  747. return_softmax,
  748. ):
  749. if softmax_scale is None:
  750. softmax_scale = q.shape[-1] ** (-0.5)
  751. head_size_og = q.size(3)
  752. if head_size_og % 8 != 0:
  753. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  754. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  755. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  756. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
  757. q,
  758. k,
  759. v,
  760. dropout_p,
  761. softmax_scale,
  762. causal=causal,
  763. window_size_left=window_size[0],
  764. window_size_right=window_size[1],
  765. softcap=softcap,
  766. alibi_slopes=alibi_slopes,
  767. return_softmax=return_softmax and dropout_p > 0,
  768. )
  769. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  770. ctx.dropout_p = dropout_p
  771. ctx.softmax_scale = softmax_scale
  772. ctx.causal = causal
  773. ctx.window_size = window_size
  774. ctx.softcap = softcap
  775. ctx.alibi_slopes = alibi_slopes
  776. ctx.deterministic = deterministic
  777. out = out_padded[..., :head_size_og]
  778. return out if not return_softmax else (out, softmax_lse, S_dmask)
  779. @staticmethod
  780. def backward(ctx, dout, *args):
  781. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  782. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  783. head_size_og = dout.size(3)
  784. dout_padded = dout
  785. if head_size_og % 8 != 0:
  786. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  787. _wrapped_flash_attn_backward(
  788. dout_padded,
  789. q,
  790. k,
  791. v,
  792. out,
  793. softmax_lse,
  794. dq,
  795. dk,
  796. dv,
  797. ctx.dropout_p,
  798. ctx.softmax_scale,
  799. ctx.causal,
  800. ctx.window_size[0],
  801. ctx.window_size[1],
  802. ctx.softcap,
  803. ctx.alibi_slopes,
  804. ctx.deterministic,
  805. rng_state=rng_state,
  806. )
  807. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  808. dk = dk[..., : dout.shape[-1]]
  809. dv = dv[..., : dout.shape[-1]]
  810. return dq, dk, dv, None, None, None, None, None, None, None, None
  811. class FlashAttnVarlenFunc(torch.autograd.Function):
  812. @staticmethod
  813. def forward(
  814. ctx,
  815. q,
  816. k,
  817. v,
  818. cu_seqlens_q,
  819. cu_seqlens_k,
  820. max_seqlen_q,
  821. max_seqlen_k,
  822. dropout_p,
  823. softmax_scale,
  824. causal,
  825. window_size,
  826. softcap,
  827. alibi_slopes,
  828. deterministic,
  829. return_softmax,
  830. block_table,
  831. ):
  832. if softmax_scale is None:
  833. softmax_scale = q.shape[-1] ** (-0.5)
  834. head_size_og = q.size(2)
  835. if head_size_og % 8 != 0:
  836. q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
  837. k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
  838. v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
  839. out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
  840. q,
  841. k,
  842. v,
  843. cu_seqlens_q,
  844. cu_seqlens_k,
  845. max_seqlen_q,
  846. max_seqlen_k,
  847. dropout_p,
  848. softmax_scale,
  849. causal=causal,
  850. window_size_left=window_size[0],
  851. window_size_right=window_size[1],
  852. softcap=softcap,
  853. alibi_slopes=alibi_slopes,
  854. return_softmax=return_softmax and dropout_p > 0,
  855. block_table=block_table,
  856. )
  857. ctx.save_for_backward(
  858. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
  859. )
  860. ctx.dropout_p = dropout_p
  861. ctx.max_seqlen_q = max_seqlen_q
  862. ctx.max_seqlen_k = max_seqlen_k
  863. ctx.softmax_scale = softmax_scale
  864. ctx.causal = causal
  865. ctx.window_size = window_size
  866. ctx.softcap = softcap
  867. ctx.alibi_slopes = alibi_slopes
  868. ctx.deterministic = deterministic
  869. out = out_padded[..., :head_size_og]
  870. return out if not return_softmax else (out, softmax_lse, S_dmask)
  871. @staticmethod
  872. def backward(ctx, dout, *args):
  873. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
  874. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  875. head_size_og = dout.size(2)
  876. dout_padded = dout
  877. if head_size_og % 8 != 0:
  878. dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
  879. _wrapped_flash_attn_varlen_backward(
  880. dout_padded,
  881. q,
  882. k,
  883. v,
  884. out,
  885. softmax_lse,
  886. dq,
  887. dk,
  888. dv,
  889. cu_seqlens_q,
  890. cu_seqlens_k,
  891. ctx.max_seqlen_q,
  892. ctx.max_seqlen_k,
  893. ctx.dropout_p,
  894. ctx.softmax_scale,
  895. ctx.causal,
  896. ctx.window_size[0],
  897. ctx.window_size[1],
  898. ctx.softcap,
  899. ctx.alibi_slopes,
  900. ctx.deterministic,
  901. rng_state=rng_state,
  902. )
  903. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  904. dk = dk[..., : dout.shape[-1]]
  905. dv = dv[..., : dout.shape[-1]]
  906. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
  907. def flash_attn_qkvpacked_func(
  908. qkv,
  909. dropout_p=0.0,
  910. softmax_scale=None,
  911. causal=False,
  912. window_size=(-1, -1), # -1 means infinite context window
  913. softcap=0.0, # <=0.0 means deactivate
  914. alibi_slopes=None,
  915. deterministic=False,
  916. return_attn_probs=False,
  917. ):
  918. """dropout_p should be set to 0.0 during evaluation
  919. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  920. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  921. of the gradients of Q, K, V.
  922. For multi-query and grouped-query attention (MQA/GQA), please see
  923. flash_attn_kvpacked_func and flash_attn_func.
  924. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  925. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  926. Arguments:
  927. qkv: (batch_size, seqlen, 3, nheads, headdim)
  928. dropout_p: float. Dropout probability.
  929. softmax_scale: float. The scaling of QK^T before applying softmax.
  930. Default to 1 / sqrt(headdim).
  931. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  932. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  933. softcap: float. Anything > 0 activates softcapping attention.
  934. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  935. the attention score of query i and key j.
  936. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  937. which is slightly slower and uses more memory. The forward pass is always deterministic.
  938. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  939. testing only. The returned probabilities are not guaranteed to be correct
  940. (they might not have the right scaling).
  941. Return:
  942. out: (batch_size, seqlen, nheads, headdim).
  943. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  944. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  945. normalization factor).
  946. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  947. The output of softmax (possibly with different scaling). It also encodes the dropout
  948. pattern (negative means that location was dropped, nonnegative means it was kept).
  949. """
  950. return FlashAttnQKVPackedFunc.apply(
  951. qkv,
  952. dropout_p,
  953. softmax_scale,
  954. causal,
  955. window_size,
  956. softcap,
  957. alibi_slopes,
  958. deterministic,
  959. return_attn_probs,
  960. )
  961. def flash_attn_kvpacked_func(
  962. q,
  963. kv,
  964. dropout_p=0.0,
  965. softmax_scale=None,
  966. causal=False,
  967. window_size=(-1, -1), # -1 means infinite context window
  968. softcap=0.0, # 0.0 means deactivated
  969. alibi_slopes=None,
  970. deterministic=False,
  971. return_attn_probs=False,
  972. ):
  973. """dropout_p should be set to 0.0 during evaluation
  974. If K, V are already stacked into 1 tensor, this function will be faster than
  975. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  976. of the gradients of K, V.
  977. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  978. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  979. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  980. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  981. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  982. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  983. 1 1 1 1 0
  984. 1 1 1 1 1
  985. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  986. 0 0
  987. 0 0
  988. 0 0
  989. 1 0
  990. 1 1
  991. If the row of the mask is all zero, the output will be zero.
  992. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  993. will only attend to keys between
  994. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  995. Arguments:
  996. q: (batch_size, seqlen, nheads, headdim)
  997. kv: (batch_size, seqlen, 2, nheads_k, headdim)
  998. dropout_p: float. Dropout probability.
  999. softmax_scale: float. The scaling of QK^T before applying softmax.
  1000. Default to 1 / sqrt(headdim).
  1001. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1002. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1003. softcap: float. Anything > 0 activates softcapping attention.
  1004. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1005. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1006. is added to the attention score of query i and key j.
  1007. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1008. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1009. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1010. testing only. The returned probabilities are not guaranteed to be correct
  1011. (they might not have the right scaling).
  1012. Return:
  1013. out: (batch_size, seqlen, nheads, headdim).
  1014. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  1015. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1016. normalization factor).
  1017. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1018. The output of softmax (possibly with different scaling). It also encodes the dropout
  1019. pattern (negative means that location was dropped, nonnegative means it was kept).
  1020. """
  1021. return FlashAttnKVPackedFunc.apply(
  1022. q,
  1023. kv,
  1024. dropout_p,
  1025. softmax_scale,
  1026. causal,
  1027. window_size,
  1028. softcap,
  1029. alibi_slopes,
  1030. deterministic,
  1031. return_attn_probs,
  1032. )
  1033. def flash_attn_func(
  1034. q,
  1035. k,
  1036. v,
  1037. dropout_p=0.0,
  1038. softmax_scale=None,
  1039. causal=False,
  1040. window_size=(-1, -1), # -1 means infinite context window
  1041. softcap=0.0, # 0.0 means deactivated
  1042. alibi_slopes=None,
  1043. deterministic=False,
  1044. return_attn_probs=False,
  1045. ):
  1046. """dropout_p should be set to 0.0 during evaluation
  1047. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1048. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1049. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1050. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1051. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1052. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1053. 1 1 1 1 0
  1054. 1 1 1 1 1
  1055. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1056. 0 0
  1057. 0 0
  1058. 0 0
  1059. 1 0
  1060. 1 1
  1061. If the row of the mask is all zero, the output will be zero.
  1062. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1063. will only attend to keys between
  1064. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1065. Arguments:
  1066. q: (batch_size, seqlen, nheads, headdim)
  1067. k: (batch_size, seqlen, nheads_k, headdim)
  1068. v: (batch_size, seqlen, nheads_k, headdim)
  1069. dropout_p: float. Dropout probability.
  1070. softmax_scale: float. The scaling of QK^T before applying softmax.
  1071. Default to 1 / sqrt(headdim).
  1072. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1073. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1074. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1075. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1076. is added to the attention score of query i and key j.
  1077. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1078. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1079. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1080. testing only. The returned probabilities are not guaranteed to be correct
  1081. (they might not have the right scaling).
  1082. Return:
  1083. out: (batch_size, seqlen, nheads, headdim).
  1084. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  1085. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1086. normalization factor).
  1087. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1088. The output of softmax (possibly with different scaling). It also encodes the dropout
  1089. pattern (negative means that location was dropped, nonnegative means it was kept).
  1090. """
  1091. return FlashAttnFunc.apply(
  1092. q,
  1093. k,
  1094. v,
  1095. dropout_p,
  1096. softmax_scale,
  1097. causal,
  1098. window_size,
  1099. softcap,
  1100. alibi_slopes,
  1101. deterministic,
  1102. return_attn_probs,
  1103. )
  1104. def flash_attn_varlen_qkvpacked_func(
  1105. qkv,
  1106. cu_seqlens,
  1107. max_seqlen,
  1108. dropout_p=0.0,
  1109. softmax_scale=None,
  1110. causal=False,
  1111. window_size=(-1, -1), # -1 means infinite context window
  1112. softcap=0.0, # 0.0 means deactivated
  1113. alibi_slopes=None,
  1114. deterministic=False,
  1115. return_attn_probs=False,
  1116. ):
  1117. """dropout_p should be set to 0.0 during evaluation
  1118. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  1119. calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
  1120. of the gradients of Q, K, V.
  1121. For multi-query and grouped-query attention (MQA/GQA), please see
  1122. flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
  1123. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1124. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  1125. Arguments:
  1126. qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
  1127. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1128. of the sequences in the batch, used to index into qkv.
  1129. max_seqlen: int. Maximum sequence length in the batch.
  1130. dropout_p: float. Dropout probability.
  1131. softmax_scale: float. The scaling of QK^T before applying softmax.
  1132. Default to 1 / sqrt(headdim).
  1133. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1134. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1135. softcap: float. Anything > 0 activates softcapping attention.
  1136. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
  1137. is added to the attention score of query i and key j.
  1138. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1139. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1140. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1141. testing only. The returned probabilities are not guaranteed to be correct
  1142. (they might not have the right scaling).
  1143. Return:
  1144. out: (total, nheads, headdim).
  1145. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1146. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1147. normalization factor).
  1148. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1149. The output of softmax (possibly with different scaling). It also encodes the dropout
  1150. pattern (negative means that location was dropped, nonnegative means it was kept).
  1151. """
  1152. return FlashAttnVarlenQKVPackedFunc.apply(
  1153. qkv,
  1154. cu_seqlens,
  1155. max_seqlen,
  1156. dropout_p,
  1157. softmax_scale,
  1158. causal,
  1159. window_size,
  1160. softcap,
  1161. alibi_slopes,
  1162. deterministic,
  1163. return_attn_probs,
  1164. )
  1165. def flash_attn_varlen_kvpacked_func(
  1166. q,
  1167. kv,
  1168. cu_seqlens_q,
  1169. cu_seqlens_k,
  1170. max_seqlen_q,
  1171. max_seqlen_k,
  1172. dropout_p=0.0,
  1173. softmax_scale=None,
  1174. causal=False,
  1175. window_size=(-1, -1), # -1 means infinite context window
  1176. softcap=0.0, # 0.0 means deactivated
  1177. alibi_slopes=None,
  1178. deterministic=False,
  1179. return_attn_probs=False,
  1180. ):
  1181. """dropout_p should be set to 0.0 during evaluation
  1182. If K, V are already stacked into 1 tensor, this function will be faster than
  1183. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  1184. of the gradients of K, V.
  1185. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1186. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1187. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1188. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1189. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1190. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1191. 1 1 1 1 0
  1192. 1 1 1 1 1
  1193. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1194. 0 0
  1195. 0 0
  1196. 0 0
  1197. 1 0
  1198. 1 1
  1199. If the row of the mask is all zero, the output will be zero.
  1200. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1201. will only attend to keys between
  1202. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1203. Arguments:
  1204. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  1205. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1206. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1207. of the sequences in the batch, used to index into q.
  1208. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1209. of the sequences in the batch, used to index into kv.
  1210. max_seqlen_q: int. Maximum query sequence length in the batch.
  1211. max_seqlen_k: int. Maximum key sequence length in the batch.
  1212. dropout_p: float. Dropout probability.
  1213. softmax_scale: float. The scaling of QK^T before applying softmax.
  1214. Default to 1 / sqrt(headdim).
  1215. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1216. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1217. softcap: float. Anything > 0 activates softcapping attention.
  1218. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1219. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1220. is added to the attention score of query i and key j.
  1221. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1222. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1223. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1224. testing only. The returned probabilities are not guaranteed to be correct
  1225. (they might not have the right scaling).
  1226. Return:
  1227. out: (total, nheads, headdim).
  1228. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1229. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1230. normalization factor).
  1231. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1232. The output of softmax (possibly with different scaling). It also encodes the dropout
  1233. pattern (negative means that location was dropped, nonnegative means it was kept).
  1234. """
  1235. return FlashAttnVarlenKVPackedFunc.apply(
  1236. q,
  1237. kv,
  1238. cu_seqlens_q,
  1239. cu_seqlens_k,
  1240. max_seqlen_q,
  1241. max_seqlen_k,
  1242. dropout_p,
  1243. softmax_scale,
  1244. causal,
  1245. window_size,
  1246. softcap,
  1247. alibi_slopes,
  1248. deterministic,
  1249. return_attn_probs,
  1250. )
  1251. def flash_attn_varlen_func(
  1252. q,
  1253. k,
  1254. v,
  1255. cu_seqlens_q,
  1256. cu_seqlens_k,
  1257. max_seqlen_q,
  1258. max_seqlen_k,
  1259. dropout_p=0.0,
  1260. softmax_scale=None,
  1261. causal=False,
  1262. window_size=(-1, -1), # -1 means infinite context window
  1263. softcap=0.0, # 0.0 means deactivated
  1264. alibi_slopes=None,
  1265. deterministic=False,
  1266. return_attn_probs=False,
  1267. block_table=None,
  1268. ):
  1269. """dropout_p should be set to 0.0 during evaluation
  1270. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  1271. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1272. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1273. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1274. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1275. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1276. 1 1 1 1 0
  1277. 1 1 1 1 1
  1278. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1279. 0 0
  1280. 0 0
  1281. 0 0
  1282. 1 0
  1283. 1 1
  1284. If the row of the mask is all zero, the output will be zero.
  1285. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1286. will only attend to keys between
  1287. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1288. Arguments:
  1289. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  1290. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1291. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  1292. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1293. of the sequences in the batch, used to index into q.
  1294. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  1295. of the sequences in the batch, used to index into kv.
  1296. max_seqlen_q: int. Maximum query sequence length in the batch.
  1297. max_seqlen_k: int. Maximum key sequence length in the batch.
  1298. dropout_p: float. Dropout probability.
  1299. softmax_scale: float. The scaling of QK^T before applying softmax.
  1300. Default to 1 / sqrt(headdim).
  1301. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1302. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1303. softcap: float. Anything > 0 activates softcapping attention.
  1304. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1305. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1306. is added to the attention score of query i and key j.
  1307. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  1308. which is slightly slower and uses more memory. The forward pass is always deterministic.
  1309. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  1310. testing only. The returned probabilities are not guaranteed to be correct
  1311. (they might not have the right scaling).
  1312. Return:
  1313. out: (total, nheads, headdim).
  1314. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  1315. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1316. normalization factor).
  1317. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  1318. The output of softmax (possibly with different scaling). It also encodes the dropout
  1319. pattern (negative means that location was dropped, nonnegative means it was kept).
  1320. """
  1321. return FlashAttnVarlenFunc.apply(
  1322. q,
  1323. k,
  1324. v,
  1325. cu_seqlens_q,
  1326. cu_seqlens_k,
  1327. max_seqlen_q,
  1328. max_seqlen_k,
  1329. dropout_p,
  1330. softmax_scale,
  1331. causal,
  1332. window_size,
  1333. softcap,
  1334. alibi_slopes,
  1335. deterministic,
  1336. return_attn_probs,
  1337. block_table,
  1338. )
  1339. def flash_attn_with_kvcache(
  1340. q,
  1341. k_cache,
  1342. v_cache,
  1343. k=None,
  1344. v=None,
  1345. rotary_cos=None,
  1346. rotary_sin=None,
  1347. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  1348. cache_batch_idx: Optional[torch.Tensor] = None,
  1349. cache_leftpad: Optional[torch.Tensor] = None,
  1350. block_table: Optional[torch.Tensor] = None,
  1351. softmax_scale=None,
  1352. causal=False,
  1353. window_size=(-1, -1), # -1 means infinite context window
  1354. softcap=0.0, # 0.0 means deactivated
  1355. rotary_interleaved=True,
  1356. alibi_slopes=None,
  1357. num_splits=0,
  1358. return_softmax_lse=False,
  1359. ):
  1360. """
  1361. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  1362. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  1363. the previous step, and update them with the new keys/values from the current step, and do
  1364. attention with the updated cache, all in 1 kernel.
  1365. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  1366. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  1367. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  1368. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  1369. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1370. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  1371. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1372. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  1373. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  1374. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  1375. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1376. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1377. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1378. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1379. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1380. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1381. 1 1 1 1 0
  1382. 1 1 1 1 1
  1383. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1384. 0 0
  1385. 0 0
  1386. 0 0
  1387. 1 0
  1388. 1 1
  1389. If the row of the mask is all zero, the output will be zero.
  1390. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1391. will only attend to keys between
  1392. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1393. Note: Does not support backward pass.
  1394. Arguments:
  1395. q: (batch_size, seqlen, nheads, headdim)
  1396. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1397. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1398. page_block_size must be a multiple of 256.
  1399. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1400. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1401. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  1402. k with k_cache, starting at the indices specified by cache_seqlens.
  1403. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  1404. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  1405. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  1406. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  1407. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  1408. KV cache.
  1409. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  1410. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  1411. If the indices are not distinct, and k and v are provided, the values updated in the cache
  1412. might come from any of the duplicate indices.
  1413. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  1414. block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  1415. softmax_scale: float. The scaling of QK^T before applying softmax.
  1416. Default to 1 / sqrt(headdim).
  1417. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1418. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1419. softcap: float. Anything > 0 activates softcapping attention.
  1420. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  1421. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  1422. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  1423. (i.e. GPT-NeoX style).
  1424. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1425. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1426. is added to the attention score of query i and key j.
  1427. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  1428. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  1429. to automatically determine the number of splits.
  1430. Don't change this unless you know what you are doing.
  1431. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  1432. Return:
  1433. out: (batch_size, seqlen, nheads, headdim).
  1434. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  1435. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1436. normalization factor).
  1437. """
  1438. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  1439. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  1440. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  1441. if softmax_scale is None:
  1442. softmax_scale = q.shape[-1] ** (-0.5)
  1443. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  1444. cache_seqlens = torch.full(
  1445. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  1446. )
  1447. cache_seqlens = maybe_contiguous(cache_seqlens)
  1448. cache_batch_idx = maybe_contiguous(cache_batch_idx)
  1449. block_table = maybe_contiguous(block_table)
  1450. out, softmax_lse = flash_attn_cuda.fwd_kvcache(
  1451. q,
  1452. k_cache,
  1453. v_cache,
  1454. k,
  1455. v,
  1456. cache_seqlens,
  1457. rotary_cos,
  1458. rotary_sin,
  1459. cache_batch_idx,
  1460. cache_leftpad,
  1461. block_table,
  1462. alibi_slopes,
  1463. None,
  1464. softmax_scale,
  1465. causal,
  1466. window_size[0],
  1467. window_size[1],
  1468. softcap,
  1469. rotary_interleaved,
  1470. num_splits,
  1471. )
  1472. return (out, softmax_lse) if return_softmax_lse else out