flash_attn_interface.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flash_attn_3_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(
  12. q,
  13. k,
  14. v,
  15. k_new,
  16. v_new,
  17. out,
  18. cu_seqlens_q,
  19. cu_seqlens_k,
  20. cu_seqlens_k_new,
  21. seqused_q,
  22. seqused_k,
  23. max_seqlen_q,
  24. max_seqlen_k,
  25. page_table,
  26. kv_batch_idx,
  27. leftpad_k,
  28. rotary_cos,
  29. rotary_sin,
  30. q_descale,
  31. k_descale,
  32. v_descale,
  33. softmax_scale,
  34. causal,
  35. window_size=(-1, -1),
  36. sink_token_length=0,
  37. softcap=0.0,
  38. rotary_interleaved=True,
  39. num_splits=1,
  40. pack_gqa=None,
  41. sm_margin=0):
  42. assert sink_token_length == 0, "sink_token_length not supported yet"
  43. q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
  44. v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
  45. cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
  46. maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
  47. ]
  48. seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
  49. page_table, kv_batch_idx, leftpad_k = [
  50. maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
  51. ]
  52. rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
  53. out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
  54. q,
  55. k,
  56. v,
  57. k_new,
  58. v_new,
  59. out,
  60. cu_seqlens_q,
  61. cu_seqlens_k,
  62. cu_seqlens_k_new,
  63. seqused_q,
  64. seqused_k,
  65. max_seqlen_q,
  66. max_seqlen_k,
  67. page_table,
  68. kv_batch_idx,
  69. leftpad_k,
  70. rotary_cos,
  71. rotary_sin,
  72. q_descale,
  73. k_descale,
  74. v_descale,
  75. softmax_scale,
  76. causal,
  77. window_size[0],
  78. window_size[1],
  79. sink_token_length,
  80. softcap,
  81. rotary_interleaved,
  82. num_splits,
  83. pack_gqa,
  84. sm_margin,
  85. )
  86. return (out, softmax_lse, *rest)
  87. def _flash_attn_backward(
  88. dout,
  89. q,
  90. k,
  91. v,
  92. out,
  93. softmax_lse,
  94. cu_seqlens_q,
  95. cu_seqlens_k,
  96. sequed_q,
  97. sequed_k,
  98. max_seqlen_q,
  99. max_seqlen_k,
  100. dq,
  101. dk,
  102. dv,
  103. softmax_scale,
  104. causal,
  105. window_size=(-1, -1),
  106. sink_token_length=0,
  107. softcap=0.0,
  108. deterministic=False,
  109. sm_margin=0,
  110. ):
  111. assert sink_token_length == 0, "sink_token_length not supported yet"
  112. # dq, dk, dv are allocated by us so they should already be contiguous
  113. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  114. dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
  115. dout,
  116. q,
  117. k,
  118. v,
  119. out,
  120. softmax_lse,
  121. dq,
  122. dk,
  123. dv,
  124. cu_seqlens_q,
  125. cu_seqlens_k,
  126. sequed_q,
  127. sequed_k,
  128. max_seqlen_q,
  129. max_seqlen_k,
  130. softmax_scale,
  131. causal,
  132. window_size[0],
  133. window_size[1],
  134. sink_token_length,
  135. softcap,
  136. deterministic,
  137. sm_margin,
  138. )
  139. return dq, dk, dv, softmax_d
  140. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  141. @staticmethod
  142. def forward(
  143. ctx,
  144. qkv,
  145. softmax_scale,
  146. causal,
  147. q_descale=None, k_descale=None, v_descale=None,
  148. window_size=(-1, -1),
  149. sink_token_length=0,
  150. softcap=0.0,
  151. deterministic=False,
  152. num_heads_q=None,
  153. ):
  154. if softmax_scale is None:
  155. softmax_scale = qkv.shape[-1] ** (-0.5)
  156. if qkv.dim() == 5:
  157. assert qkv.shape[-3] == 3
  158. q, k, v = qkv.unbind(dim=-3)
  159. else:
  160. assert qkv.dim() == 4
  161. assert num_heads_q is not None
  162. num_heads_k = (qkv.shape[2] - num_heads_q) // 2
  163. assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
  164. q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  165. out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  166. q,
  167. k,
  168. v,
  169. softmax_scale,
  170. causal=causal,
  171. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  172. window_size=window_size, sink_token_length=sink_token_length,
  173. softcap=softcap,
  174. )
  175. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  176. ctx.softmax_scale = softmax_scale
  177. ctx.causal = causal
  178. ctx.window_size = window_size
  179. ctx.sink_token_length = sink_token_length
  180. ctx.softcap = softcap
  181. ctx.deterministic = deterministic
  182. ctx.ndim = qkv.dim()
  183. # return out, softmax_lse
  184. return out
  185. @staticmethod
  186. def backward(ctx, dout, *args):
  187. q, k, v, out, softmax_lse = ctx.saved_tensors
  188. if ctx.ndim == 5:
  189. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  190. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  191. dq, dk, dv = dqkv.unbind(dim=-3)
  192. else:
  193. num_heads_q = q.shape[2]
  194. num_heads_k = k.shape[2]
  195. qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
  196. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  197. dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  198. _flash_attn_backward(
  199. dout,
  200. q,
  201. k,
  202. v,
  203. out,
  204. softmax_lse,
  205. dq,
  206. dk,
  207. dv,
  208. ctx.softmax_scale,
  209. ctx.causal,
  210. ctx.window_size,
  211. ctx.sink_token_length,
  212. ctx.softcap,
  213. ctx.deterministic,
  214. )
  215. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  216. return dqkv, None, None, None, None, None, None, None, None, None, None
  217. class FlashAttnFunc(torch.autograd.Function):
  218. @staticmethod
  219. def forward(
  220. ctx,
  221. q,
  222. k,
  223. v,
  224. softmax_scale,
  225. causal,
  226. q_descale=None, k_descale=None, v_descale=None,
  227. window_size=(-1, -1),
  228. sink_token_length=0,
  229. softcap=0.0,
  230. num_splits=1,
  231. pack_gqa=None,
  232. deterministic=False,
  233. sm_margin=0,
  234. ):
  235. if softmax_scale is None:
  236. softmax_scale = q.shape[-1] ** (-0.5)
  237. # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  238. out, softmax_lse, *rest = _flash_attn_forward(
  239. q,
  240. k,
  241. v,
  242. None, None, # k_new, v_new
  243. None, # out
  244. None, None, None, # cu_seqlens_q/k/k_new
  245. None, None, # seqused_q/k
  246. None, None, # max_seqlen_q/k
  247. None, None, None, # page_table, kv_batch_idx, leftpad_k,
  248. None, None, # rotary_cos/sin
  249. q_descale, k_descale, v_descale,
  250. softmax_scale,
  251. causal=causal,
  252. window_size=window_size,
  253. sink_token_length=sink_token_length,
  254. softcap=softcap,
  255. num_splits=num_splits,
  256. pack_gqa=pack_gqa,
  257. sm_margin=sm_margin,
  258. )
  259. # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  260. ctx.save_for_backward(q, k, v, out, softmax_lse)
  261. ctx.softmax_scale = softmax_scale
  262. ctx.causal = causal
  263. ctx.window_size = window_size
  264. ctx.sink_token_length = sink_token_length
  265. ctx.softcap = softcap
  266. ctx.deterministic = deterministic
  267. ctx.sm_margin = sm_margin
  268. return out, softmax_lse
  269. @staticmethod
  270. def backward(ctx, dout, *args):
  271. q, k, v, out, softmax_lse = ctx.saved_tensors
  272. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  273. _flash_attn_backward(
  274. dout,
  275. q,
  276. k,
  277. v,
  278. out,
  279. softmax_lse,
  280. None, None, # cu_seqlens_q, cu_seqlens_k,
  281. None, None, # sequed_q, sequed_k,
  282. None, None, # max_seqlen_q, max_seqlen_k,
  283. dq,
  284. dk,
  285. dv,
  286. ctx.softmax_scale,
  287. ctx.causal,
  288. ctx.window_size,
  289. ctx.sink_token_length,
  290. ctx.softcap,
  291. ctx.deterministic,
  292. ctx.sm_margin,
  293. )
  294. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  295. dk = dk[..., : dout.shape[-1]]
  296. dv = dv[..., : dout.shape[-1]]
  297. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
  298. class FlashAttnVarlenFunc(torch.autograd.Function):
  299. @staticmethod
  300. def forward(
  301. ctx,
  302. q,
  303. k,
  304. v,
  305. cu_seqlens_q,
  306. cu_seqlens_k,
  307. seqused_q,
  308. seqused_k,
  309. max_seqlen_q,
  310. max_seqlen_k,
  311. softmax_scale,
  312. causal,
  313. q_descale=None, k_descale=None, v_descale=None,
  314. window_size=(-1, -1),
  315. sink_token_length=0,
  316. softcap=0.0,
  317. num_splits=1,
  318. pack_gqa=None,
  319. deterministic=False,
  320. sm_margin=0,
  321. ):
  322. if softmax_scale is None:
  323. softmax_scale = q.shape[-1] ** (-0.5)
  324. # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  325. out, softmax_lse, *rest = _flash_attn_forward(
  326. q,
  327. k,
  328. v,
  329. None, None, # k_new, v_new
  330. None, # out
  331. cu_seqlens_q,
  332. cu_seqlens_k,
  333. None, # cu_seqlens_k_new
  334. seqused_q,
  335. seqused_k,
  336. max_seqlen_q,
  337. max_seqlen_k,
  338. None, None, None, # page_table, kv_batch_idx, leftpad_k,
  339. None, None, # rotary_cos/sin
  340. q_descale, k_descale, v_descale,
  341. softmax_scale,
  342. causal=causal,
  343. window_size=window_size,
  344. sink_token_length=sink_token_length,
  345. softcap=softcap,
  346. num_splits=num_splits,
  347. pack_gqa=pack_gqa,
  348. sm_margin=sm_margin,
  349. )
  350. # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
  351. ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
  352. ctx.max_seqlen_q = max_seqlen_q
  353. ctx.max_seqlen_k = max_seqlen_k
  354. ctx.softmax_scale = softmax_scale
  355. ctx.causal = causal
  356. ctx.window_size = window_size
  357. ctx.sink_token_length = sink_token_length
  358. ctx.softcap = softcap
  359. ctx.deterministic = deterministic
  360. ctx.sm_margin = sm_margin
  361. return out, softmax_lse
  362. @staticmethod
  363. def backward(ctx, dout, *args):
  364. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
  365. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  366. _flash_attn_backward(
  367. dout,
  368. q,
  369. k,
  370. v,
  371. out,
  372. softmax_lse,
  373. cu_seqlens_q,
  374. cu_seqlens_k,
  375. seqused_q,
  376. seqused_k,
  377. ctx.max_seqlen_q,
  378. ctx.max_seqlen_k,
  379. dq,
  380. dk,
  381. dv,
  382. ctx.softmax_scale,
  383. ctx.causal,
  384. ctx.window_size,
  385. ctx.sink_token_length,
  386. ctx.softcap,
  387. ctx.deterministic,
  388. ctx.sm_margin,
  389. )
  390. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  391. dk = dk[..., : dout.shape[-1]]
  392. dv = dv[..., : dout.shape[-1]]
  393. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
  394. def flash_attn_qkvpacked_func(
  395. qkv,
  396. softmax_scale=None,
  397. causal=False,
  398. q_descale=None, k_descale=None, v_descale=None,
  399. window_size=(-1, -1),
  400. sink_token_length=0,
  401. softcap=0.0,
  402. deterministic=False,
  403. num_heads_q=None,
  404. ):
  405. """dropout_p should be set to 0.0 during evaluation
  406. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  407. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  408. of the gradients of Q, K, V.
  409. For multi-query and grouped-query attention (MQA/GQA), please see
  410. flash_attn_kvpacked_func and flash_attn_func.
  411. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  412. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  413. Arguments:
  414. qkv: (batch_size, seqlen, 3, nheads, headdim)
  415. dropout_p: float. Dropout probability.
  416. softmax_scale: float. The scaling of QK^T before applying softmax.
  417. Default to 1 / sqrt(headdim).
  418. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  419. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  420. softcap: float. Anything > 0 activates softcapping attention.
  421. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  422. the attention score of query i and key j.
  423. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  424. which is slightly slower and uses more memory. The forward pass is always deterministic.
  425. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  426. testing only. The returned probabilities are not guaranteed to be correct
  427. (they might not have the right scaling).
  428. Return:
  429. out: (batch_size, seqlen, nheads, headdim).
  430. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  431. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  432. normalization factor).
  433. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  434. The output of softmax (possibly with different scaling). It also encodes the dropout
  435. pattern (negative means that location was dropped, nonnegative means it was kept).
  436. """
  437. return FlashAttnQKVPackedFunc.apply(
  438. qkv,
  439. softmax_scale,
  440. causal,
  441. q_descale, k_descale, v_descale,
  442. window_size,
  443. sink_token_length,
  444. softcap,
  445. deterministic,
  446. num_heads_q,
  447. )
  448. def flash_attn_func(
  449. q,
  450. k,
  451. v,
  452. softmax_scale=None,
  453. causal=False,
  454. q_descale=None, k_descale=None, v_descale=None,
  455. window_size=(-1, -1),
  456. sink_token_length=0,
  457. softcap=0.0,
  458. num_splits=1,
  459. pack_gqa=None,
  460. deterministic=False,
  461. sm_margin=0,
  462. ):
  463. """dropout_p should be set to 0.0 during evaluation
  464. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  465. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  466. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  467. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  468. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  469. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  470. 1 1 1 1 0
  471. 1 1 1 1 1
  472. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  473. 0 0
  474. 0 0
  475. 0 0
  476. 1 0
  477. 1 1
  478. If the row of the mask is all zero, the output will be zero.
  479. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  480. will only attend to keys between
  481. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  482. Arguments:
  483. q: (batch_size, seqlen, nheads, headdim)
  484. k: (batch_size, seqlen, nheads_k, headdim)
  485. v: (batch_size, seqlen, nheads_k, headdim)
  486. dropout_p: float. Dropout probability.
  487. softmax_scale: float. The scaling of QK^T before applying softmax.
  488. Default to 1 / sqrt(headdim).
  489. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  490. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  491. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  492. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  493. is added to the attention score of query i and key j.
  494. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  495. which is slightly slower and uses more memory. The forward pass is always deterministic.
  496. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  497. testing only. The returned probabilities are not guaranteed to be correct
  498. (they might not have the right scaling).
  499. Return:
  500. out: (batch_size, seqlen, nheads, headdim).
  501. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  502. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  503. normalization factor).
  504. """
  505. return FlashAttnFunc.apply(
  506. q,
  507. k,
  508. v,
  509. softmax_scale,
  510. causal,
  511. q_descale, k_descale, v_descale,
  512. window_size,
  513. sink_token_length,
  514. softcap,
  515. num_splits,
  516. pack_gqa,
  517. deterministic,
  518. sm_margin,
  519. )
  520. def flash_attn_varlen_func(
  521. q,
  522. k,
  523. v,
  524. cu_seqlens_q,
  525. cu_seqlens_k,
  526. seqused_q,
  527. seqused_k,
  528. max_seqlen_q,
  529. max_seqlen_k,
  530. softmax_scale=None,
  531. causal=False,
  532. q_descale=None, k_descale=None, v_descale=None,
  533. window_size=(-1, -1),
  534. sink_token_length=0,
  535. softcap=0.0,
  536. num_splits=1,
  537. pack_gqa=None,
  538. deterministic=False,
  539. sm_margin=0,
  540. ):
  541. return FlashAttnVarlenFunc.apply(
  542. q,
  543. k,
  544. v,
  545. cu_seqlens_q,
  546. cu_seqlens_k,
  547. seqused_q,
  548. seqused_k,
  549. max_seqlen_q,
  550. max_seqlen_k,
  551. softmax_scale,
  552. causal,
  553. q_descale, k_descale, v_descale,
  554. window_size,
  555. sink_token_length,
  556. softcap,
  557. num_splits,
  558. pack_gqa,
  559. deterministic,
  560. sm_margin,
  561. )
  562. def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
  563. return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
  564. def flash_attn_with_kvcache(
  565. q,
  566. k_cache,
  567. v_cache,
  568. k=None,
  569. v=None,
  570. rotary_cos=None,
  571. rotary_sin=None,
  572. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  573. cache_batch_idx: Optional[torch.Tensor] = None,
  574. cache_leftpad: Optional[torch.Tensor] = None,
  575. page_table: Optional[torch.Tensor] = None,
  576. cu_seqlens_q: Optional[torch.Tensor] = None,
  577. cu_seqlens_k_new: Optional[torch.Tensor] = None,
  578. max_seqlen_q: Optional[int] = None,
  579. q_descale: Optional[torch.Tensor] = None,
  580. k_descale: Optional[torch.Tensor] = None,
  581. v_descale: Optional[torch.Tensor] = None,
  582. softmax_scale=None,
  583. causal=False,
  584. window_size=(-1, -1), # -1 means infinite context window
  585. sink_token_length=0,
  586. softcap=0.0, # 0.0 means deactivated
  587. rotary_interleaved=True,
  588. num_splits=0, # Can be tuned for speed
  589. pack_gqa=None, # Can be tuned for speed
  590. sm_margin=0, # Can be tuned if some SMs are used for communication
  591. return_softmax_lse=False,
  592. ):
  593. """
  594. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  595. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  596. the previous step, and update them with the new keys/values from the current step, and do
  597. attention with the updated cache, all in 1 kernel.
  598. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  599. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  600. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  601. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  602. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  603. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  604. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  605. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  606. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  607. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  608. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  609. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  610. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  611. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  612. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  613. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  614. 1 1 1 1 0
  615. 1 1 1 1 1
  616. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  617. 0 0
  618. 0 0
  619. 0 0
  620. 1 0
  621. 1 1
  622. If the row of the mask is all zero, the output will be zero.
  623. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  624. will only attend to keys between
  625. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  626. Note: Does not support backward pass.
  627. Arguments:
  628. q: (batch_size, seqlen, nheads, headdim)
  629. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
  630. or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
  631. page_block_size must be a multiple of 256.
  632. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no _table,
  633. or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
  634. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  635. k with k_cache, starting at the indices specified by cache_seqlens.
  636. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  637. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  638. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  639. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  640. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  641. KV cache.
  642. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  643. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  644. If the indices are not distinct, and k and v are provided, the values updated in the cache
  645. might come from any of the duplicate indices.
  646. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  647. page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  648. softmax_scale: float. The scaling of QK^T before applying softmax.
  649. Default to 1 / sqrt(headdim).
  650. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  651. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  652. softcap: float. Anything > 0 activates softcapping attention.
  653. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  654. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  655. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  656. (i.e. GPT-NeoX style).
  657. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  658. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  659. to automatically determine the number of splits.
  660. Don't change this unless you know what you are doing.
  661. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  662. Return:
  663. out: (batch_size, seqlen, nheads, headdim).
  664. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  665. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  666. normalization factor).
  667. """
  668. assert sink_token_length == 0
  669. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  670. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  671. if softmax_scale is None:
  672. softmax_scale = q.shape[-1] ** (-0.5)
  673. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  674. cache_seqlens = torch.full(
  675. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  676. )
  677. cache_seqlens = maybe_contiguous(cache_seqlens)
  678. out, softmax_lse, *rest = _flash_attn_forward(
  679. q,
  680. k_cache,
  681. v_cache,
  682. k,
  683. v,
  684. None, # out
  685. cu_seqlens_q,
  686. None, # cu_seqlens_k
  687. cu_seqlens_k_new,
  688. None, # seqused_q
  689. cache_seqlens,
  690. max_seqlen_q,
  691. None, # max_seqlen_k
  692. page_table,
  693. cache_batch_idx,
  694. cache_leftpad,
  695. rotary_cos,
  696. rotary_sin,
  697. q_descale, k_descale, v_descale,
  698. softmax_scale,
  699. causal=causal,
  700. window_size=window_size,
  701. sink_token_length=sink_token_length,
  702. softcap=softcap,
  703. rotary_interleaved=rotary_interleaved,
  704. num_splits=num_splits,
  705. pack_gqa=pack_gqa,
  706. sm_margin=sm_margin,
  707. )
  708. # return (out, softmax_lse) if return_softmax_lse else out
  709. return (out, softmax_lse, *rest) if return_softmax_lse else out