flash_attn_interface.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flash_attn_3_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(
  12. q,
  13. k,
  14. v,
  15. k_new,
  16. v_new,
  17. qv,
  18. out,
  19. cu_seqlens_q,
  20. cu_seqlens_k,
  21. cu_seqlens_k_new,
  22. seqused_q,
  23. seqused_k,
  24. max_seqlen_q,
  25. max_seqlen_k,
  26. page_table,
  27. kv_batch_idx,
  28. leftpad_k,
  29. rotary_cos,
  30. rotary_sin,
  31. q_descale,
  32. k_descale,
  33. v_descale,
  34. softmax_scale,
  35. causal,
  36. window_size=(-1, -1),
  37. softcap=0.0,
  38. rotary_interleaved=True,
  39. num_splits=1,
  40. pack_gqa=None,
  41. sm_margin=0):
  42. q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
  43. v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
  44. cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
  45. maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
  46. ]
  47. seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
  48. page_table, kv_batch_idx, leftpad_k = [
  49. maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
  50. ]
  51. rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
  52. out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
  53. q,
  54. k,
  55. v,
  56. k_new,
  57. v_new,
  58. qv,
  59. out,
  60. cu_seqlens_q,
  61. cu_seqlens_k,
  62. cu_seqlens_k_new,
  63. seqused_q,
  64. seqused_k,
  65. max_seqlen_q,
  66. max_seqlen_k,
  67. page_table,
  68. kv_batch_idx,
  69. leftpad_k,
  70. rotary_cos,
  71. rotary_sin,
  72. q_descale,
  73. k_descale,
  74. v_descale,
  75. softmax_scale,
  76. causal,
  77. window_size[0],
  78. window_size[1],
  79. softcap,
  80. rotary_interleaved,
  81. num_splits,
  82. pack_gqa,
  83. sm_margin,
  84. )
  85. return (out, softmax_lse, *rest)
  86. def _flash_attn_backward(
  87. dout,
  88. q,
  89. k,
  90. v,
  91. out,
  92. softmax_lse,
  93. cu_seqlens_q,
  94. cu_seqlens_k,
  95. sequed_q,
  96. sequed_k,
  97. max_seqlen_q,
  98. max_seqlen_k,
  99. dq,
  100. dk,
  101. dv,
  102. softmax_scale,
  103. causal,
  104. window_size=(-1, -1),
  105. softcap=0.0,
  106. deterministic=False,
  107. sm_margin=0,
  108. ):
  109. # dq, dk, dv are allocated by us so they should already be contiguous
  110. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  111. dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
  112. dout,
  113. q,
  114. k,
  115. v,
  116. out,
  117. softmax_lse,
  118. dq,
  119. dk,
  120. dv,
  121. cu_seqlens_q,
  122. cu_seqlens_k,
  123. sequed_q,
  124. sequed_k,
  125. max_seqlen_q,
  126. max_seqlen_k,
  127. softmax_scale,
  128. causal,
  129. window_size[0],
  130. window_size[1],
  131. softcap,
  132. deterministic,
  133. sm_margin,
  134. )
  135. return dq, dk, dv, softmax_d
  136. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  137. @staticmethod
  138. def forward(
  139. ctx,
  140. qkv,
  141. softmax_scale,
  142. causal,
  143. q_descale=None, k_descale=None, v_descale=None,
  144. window_size=(-1, -1),
  145. softcap=0.0,
  146. deterministic=False,
  147. num_heads_q=None,
  148. ):
  149. if softmax_scale is None:
  150. softmax_scale = qkv.shape[-1] ** (-0.5)
  151. if qkv.dim() == 5:
  152. assert qkv.shape[-3] == 3
  153. q, k, v = qkv.unbind(dim=-3)
  154. else:
  155. assert qkv.dim() == 4
  156. assert num_heads_q is not None
  157. num_heads_k = (qkv.shape[2] - num_heads_q) // 2
  158. assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
  159. q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  160. out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  161. q,
  162. k,
  163. v,
  164. softmax_scale,
  165. causal=causal,
  166. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  167. window_size=window_size,
  168. softcap=softcap,
  169. )
  170. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  171. ctx.softmax_scale = softmax_scale
  172. ctx.causal = causal
  173. ctx.window_size = window_size
  174. ctx.softcap = softcap
  175. ctx.deterministic = deterministic
  176. ctx.ndim = qkv.dim()
  177. # return out, softmax_lse
  178. return out
  179. @staticmethod
  180. def backward(ctx, dout, *args):
  181. q, k, v, out, softmax_lse = ctx.saved_tensors
  182. if ctx.ndim == 5:
  183. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  184. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  185. dq, dk, dv = dqkv.unbind(dim=-3)
  186. else:
  187. num_heads_q = q.shape[2]
  188. num_heads_k = k.shape[2]
  189. qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
  190. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  191. dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  192. _flash_attn_backward(
  193. dout,
  194. q,
  195. k,
  196. v,
  197. out,
  198. softmax_lse,
  199. dq,
  200. dk,
  201. dv,
  202. ctx.softmax_scale,
  203. ctx.causal,
  204. ctx.window_size,
  205. ctx.softcap,
  206. ctx.deterministic,
  207. )
  208. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  209. return dqkv, None, None, None, None, None, None, None, None, None, None
  210. class FlashAttnFunc(torch.autograd.Function):
  211. @staticmethod
  212. def forward(
  213. ctx,
  214. q,
  215. k,
  216. v,
  217. softmax_scale,
  218. causal,
  219. qv=None,
  220. q_descale=None, k_descale=None, v_descale=None,
  221. window_size=(-1, -1),
  222. softcap=0.0,
  223. num_splits=1,
  224. pack_gqa=None,
  225. deterministic=False,
  226. sm_margin=0,
  227. ):
  228. if softmax_scale is None:
  229. softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
  230. # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  231. out, softmax_lse, *rest = _flash_attn_forward(
  232. q,
  233. k,
  234. v,
  235. None, None, # k_new, v_new
  236. qv, # qv
  237. None, # out
  238. None, None, None, # cu_seqlens_q/k/k_new
  239. None, None, # seqused_q/k
  240. None, None, # max_seqlen_q/k
  241. None, None, None, # page_table, kv_batch_idx, leftpad_k,
  242. None, None, # rotary_cos/sin
  243. q_descale, k_descale, v_descale,
  244. softmax_scale,
  245. causal=causal,
  246. window_size=window_size,
  247. softcap=softcap,
  248. num_splits=num_splits,
  249. pack_gqa=pack_gqa,
  250. sm_margin=sm_margin,
  251. )
  252. # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  253. ctx.save_for_backward(q, k, v, out, softmax_lse)
  254. ctx.softmax_scale = softmax_scale
  255. ctx.causal = causal
  256. ctx.window_size = window_size
  257. ctx.softcap = softcap
  258. ctx.deterministic = deterministic
  259. ctx.sm_margin = sm_margin
  260. return out, softmax_lse
  261. @staticmethod
  262. def backward(ctx, dout, *args):
  263. q, k, v, out, softmax_lse = ctx.saved_tensors
  264. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  265. _flash_attn_backward(
  266. dout,
  267. q,
  268. k,
  269. v,
  270. out,
  271. softmax_lse,
  272. None, None, # cu_seqlens_q, cu_seqlens_k,
  273. None, None, # sequed_q, sequed_k,
  274. None, None, # max_seqlen_q, max_seqlen_k,
  275. dq,
  276. dk,
  277. dv,
  278. ctx.softmax_scale,
  279. ctx.causal,
  280. ctx.window_size,
  281. ctx.softcap,
  282. ctx.deterministic,
  283. ctx.sm_margin,
  284. )
  285. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  286. dk = dk[..., : dout.shape[-1]]
  287. dv = dv[..., : dout.shape[-1]]
  288. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
  289. class FlashAttnVarlenFunc(torch.autograd.Function):
  290. @staticmethod
  291. def forward(
  292. ctx,
  293. q,
  294. k,
  295. v,
  296. cu_seqlens_q,
  297. cu_seqlens_k,
  298. seqused_q,
  299. seqused_k,
  300. max_seqlen_q,
  301. max_seqlen_k,
  302. softmax_scale,
  303. causal,
  304. qv=None,
  305. q_descale=None, k_descale=None, v_descale=None,
  306. window_size=(-1, -1),
  307. softcap=0.0,
  308. num_splits=1,
  309. pack_gqa=None,
  310. deterministic=False,
  311. sm_margin=0,
  312. ):
  313. if softmax_scale is None:
  314. softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
  315. # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  316. out, softmax_lse, *rest = _flash_attn_forward(
  317. q,
  318. k,
  319. v,
  320. None, None, # k_new, v_new
  321. qv, # qv
  322. None, # out
  323. cu_seqlens_q,
  324. cu_seqlens_k,
  325. None, # cu_seqlens_k_new
  326. seqused_q,
  327. seqused_k,
  328. max_seqlen_q,
  329. max_seqlen_k,
  330. None, None, None, # page_table, kv_batch_idx, leftpad_k,
  331. None, None, # rotary_cos/sin
  332. q_descale, k_descale, v_descale,
  333. softmax_scale,
  334. causal=causal,
  335. window_size=window_size,
  336. softcap=softcap,
  337. num_splits=num_splits,
  338. pack_gqa=pack_gqa,
  339. sm_margin=sm_margin,
  340. )
  341. # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
  342. ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
  343. ctx.max_seqlen_q = max_seqlen_q
  344. ctx.max_seqlen_k = max_seqlen_k
  345. ctx.softmax_scale = softmax_scale
  346. ctx.causal = causal
  347. ctx.window_size = window_size
  348. ctx.softcap = softcap
  349. ctx.deterministic = deterministic
  350. ctx.sm_margin = sm_margin
  351. return out, softmax_lse
  352. @staticmethod
  353. def backward(ctx, dout, *args):
  354. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
  355. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  356. _flash_attn_backward(
  357. dout,
  358. q,
  359. k,
  360. v,
  361. out,
  362. softmax_lse,
  363. cu_seqlens_q,
  364. cu_seqlens_k,
  365. seqused_q,
  366. seqused_k,
  367. ctx.max_seqlen_q,
  368. ctx.max_seqlen_k,
  369. dq,
  370. dk,
  371. dv,
  372. ctx.softmax_scale,
  373. ctx.causal,
  374. ctx.window_size,
  375. ctx.softcap,
  376. ctx.deterministic,
  377. ctx.sm_margin,
  378. )
  379. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  380. dk = dk[..., : dout.shape[-1]]
  381. dv = dv[..., : dout.shape[-1]]
  382. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
  383. def flash_attn_qkvpacked_func(
  384. qkv,
  385. softmax_scale=None,
  386. causal=False,
  387. q_descale=None, k_descale=None, v_descale=None,
  388. window_size=(-1, -1),
  389. softcap=0.0,
  390. deterministic=False,
  391. num_heads_q=None,
  392. ):
  393. """dropout_p should be set to 0.0 during evaluation
  394. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  395. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  396. of the gradients of Q, K, V.
  397. For multi-query and grouped-query attention (MQA/GQA), please see
  398. flash_attn_kvpacked_func and flash_attn_func.
  399. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  400. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  401. Arguments:
  402. qkv: (batch_size, seqlen, 3, nheads, headdim)
  403. dropout_p: float. Dropout probability.
  404. softmax_scale: float. The scaling of QK^T before applying softmax.
  405. Default to 1 / sqrt(headdim).
  406. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  407. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  408. softcap: float. Anything > 0 activates softcapping attention.
  409. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  410. the attention score of query i and key j.
  411. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  412. which is slightly slower and uses more memory. The forward pass is always deterministic.
  413. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  414. testing only. The returned probabilities are not guaranteed to be correct
  415. (they might not have the right scaling).
  416. Return:
  417. out: (batch_size, seqlen, nheads, headdim).
  418. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  419. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  420. normalization factor).
  421. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  422. The output of softmax (possibly with different scaling). It also encodes the dropout
  423. pattern (negative means that location was dropped, nonnegative means it was kept).
  424. """
  425. return FlashAttnQKVPackedFunc.apply(
  426. qkv,
  427. softmax_scale,
  428. causal,
  429. q_descale, k_descale, v_descale,
  430. window_size,
  431. softcap,
  432. deterministic,
  433. num_heads_q,
  434. )
  435. def flash_attn_func(
  436. q,
  437. k,
  438. v,
  439. softmax_scale=None,
  440. causal=False,
  441. qv=None,
  442. q_descale=None, k_descale=None, v_descale=None,
  443. window_size=(-1, -1),
  444. softcap=0.0,
  445. num_splits=1,
  446. pack_gqa=None,
  447. deterministic=False,
  448. sm_margin=0,
  449. ):
  450. """dropout_p should be set to 0.0 during evaluation
  451. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  452. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  453. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  454. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  455. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  456. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  457. 1 1 1 1 0
  458. 1 1 1 1 1
  459. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  460. 0 0
  461. 0 0
  462. 0 0
  463. 1 0
  464. 1 1
  465. If the row of the mask is all zero, the output will be zero.
  466. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  467. will only attend to keys between
  468. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  469. Arguments:
  470. q: (batch_size, seqlen, nheads, headdim)
  471. k: (batch_size, seqlen, nheads_k, headdim)
  472. v: (batch_size, seqlen, nheads_k, headdim)
  473. dropout_p: float. Dropout probability.
  474. softmax_scale: float. The scaling of QK^T before applying softmax.
  475. Default to 1 / sqrt(headdim).
  476. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  477. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  478. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  479. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  480. is added to the attention score of query i and key j.
  481. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  482. which is slightly slower and uses more memory. The forward pass is always deterministic.
  483. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  484. testing only. The returned probabilities are not guaranteed to be correct
  485. (they might not have the right scaling).
  486. Return:
  487. out: (batch_size, seqlen, nheads, headdim).
  488. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  489. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  490. normalization factor).
  491. """
  492. return FlashAttnFunc.apply(
  493. q,
  494. k,
  495. v,
  496. softmax_scale,
  497. causal,
  498. qv,
  499. q_descale, k_descale, v_descale,
  500. window_size,
  501. softcap,
  502. num_splits,
  503. pack_gqa,
  504. deterministic,
  505. sm_margin,
  506. )
  507. def flash_attn_varlen_func(
  508. q,
  509. k,
  510. v,
  511. cu_seqlens_q,
  512. cu_seqlens_k,
  513. max_seqlen_q,
  514. max_seqlen_k,
  515. seqused_q=None,
  516. seqused_k=None,
  517. softmax_scale=None,
  518. causal=False,
  519. qv=None,
  520. q_descale=None, k_descale=None, v_descale=None,
  521. window_size=(-1, -1),
  522. softcap=0.0,
  523. num_splits=1,
  524. pack_gqa=None,
  525. deterministic=False,
  526. sm_margin=0,
  527. ):
  528. return FlashAttnVarlenFunc.apply(
  529. q,
  530. k,
  531. v,
  532. cu_seqlens_q,
  533. cu_seqlens_k,
  534. seqused_q,
  535. seqused_k,
  536. max_seqlen_q,
  537. max_seqlen_k,
  538. softmax_scale,
  539. causal,
  540. qv,
  541. q_descale, k_descale, v_descale,
  542. window_size,
  543. softcap,
  544. num_splits,
  545. pack_gqa,
  546. deterministic,
  547. sm_margin,
  548. )
  549. def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
  550. return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
  551. def flash_attn_with_kvcache(
  552. q,
  553. k_cache,
  554. v_cache,
  555. k=None,
  556. v=None,
  557. qv=None,
  558. rotary_cos=None,
  559. rotary_sin=None,
  560. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  561. cache_batch_idx: Optional[torch.Tensor] = None,
  562. cache_leftpad: Optional[torch.Tensor] = None,
  563. page_table: Optional[torch.Tensor] = None,
  564. cu_seqlens_q: Optional[torch.Tensor] = None,
  565. cu_seqlens_k_new: Optional[torch.Tensor] = None,
  566. max_seqlen_q: Optional[int] = None,
  567. q_descale: Optional[torch.Tensor] = None,
  568. k_descale: Optional[torch.Tensor] = None,
  569. v_descale: Optional[torch.Tensor] = None,
  570. softmax_scale=None,
  571. causal=False,
  572. window_size=(-1, -1), # -1 means infinite context window
  573. softcap=0.0, # 0.0 means deactivated
  574. rotary_interleaved=True,
  575. num_splits=0, # Can be tuned for speed
  576. pack_gqa=None, # Can be tuned for speed
  577. sm_margin=0, # Can be tuned if some SMs are used for communication
  578. return_softmax_lse=False,
  579. ):
  580. """
  581. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  582. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  583. the previous step, and update them with the new keys/values from the current step, and do
  584. attention with the updated cache, all in 1 kernel.
  585. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  586. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  587. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  588. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  589. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  590. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  591. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  592. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  593. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  594. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  595. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  596. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  597. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  598. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  599. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  600. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  601. 1 1 1 1 0
  602. 1 1 1 1 1
  603. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  604. 0 0
  605. 0 0
  606. 0 0
  607. 1 0
  608. 1 1
  609. If the row of the mask is all zero, the output will be zero.
  610. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  611. will only attend to keys between
  612. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  613. Note: Does not support backward pass.
  614. Arguments:
  615. q: (batch_size, seqlen, nheads, headdim)
  616. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
  617. or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
  618. page_block_size must be a multiple of 256.
  619. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
  620. or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
  621. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  622. k with k_cache, starting at the indices specified by cache_seqlens.
  623. v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
  624. qv [optional]: (batch_size, seqlen, nheads, headdim_v)
  625. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  626. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  627. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  628. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  629. KV cache.
  630. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  631. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  632. If the indices are not distinct, and k and v are provided, the values updated in the cache
  633. might come from any of the duplicate indices.
  634. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  635. page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  636. softmax_scale: float. The scaling of QK^T before applying softmax.
  637. Default to 1 / sqrt(headdim).
  638. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  639. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  640. softcap: float. Anything > 0 activates softcapping attention.
  641. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  642. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  643. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  644. (i.e. GPT-NeoX style).
  645. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  646. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  647. to automatically determine the number of splits.
  648. Don't change this unless you know what you are doing.
  649. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  650. Return:
  651. out: (batch_size, seqlen, nheads, headdim).
  652. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  653. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  654. normalization factor).
  655. """
  656. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  657. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  658. if softmax_scale is None:
  659. softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
  660. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  661. cache_seqlens = torch.full(
  662. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  663. )
  664. cache_seqlens = maybe_contiguous(cache_seqlens)
  665. out, softmax_lse, *rest = _flash_attn_forward(
  666. q,
  667. k_cache,
  668. v_cache,
  669. k,
  670. v,
  671. qv,
  672. None, # out
  673. cu_seqlens_q,
  674. None, # cu_seqlens_k
  675. cu_seqlens_k_new,
  676. None, # seqused_q
  677. cache_seqlens,
  678. max_seqlen_q,
  679. None, # max_seqlen_k
  680. page_table,
  681. cache_batch_idx,
  682. cache_leftpad,
  683. rotary_cos,
  684. rotary_sin,
  685. q_descale, k_descale, v_descale,
  686. softmax_scale,
  687. causal=causal,
  688. window_size=window_size,
  689. softcap=softcap,
  690. rotary_interleaved=rotary_interleaved,
  691. num_splits=num_splits,
  692. pack_gqa=pack_gqa,
  693. sm_margin=sm_margin,
  694. )
  695. # return (out, softmax_lse) if return_softmax_lse else out
  696. return (out, softmax_lse, *rest) if return_softmax_lse else out