1
0

flash_attn_interface.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(q, k, v, softmax_scale, causal,
  12. q_descale=None, k_descale=None, v_descale=None,
  13. window_size=(-1, -1),
  14. sink_token_length=0,
  15. softcap=0.0,
  16. num_splits=1,
  17. pack_gqa=None):
  18. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  19. q, k = [maybe_contiguous(x) for x in (q, k)]
  20. v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
  21. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.fwd(
  22. q,
  23. k,
  24. v,
  25. None,
  26. softmax_scale,
  27. causal,
  28. q_descale, k_descale, v_descale,
  29. window_size[0], window_size[1], sink_token_length,
  30. softcap,
  31. num_splits,
  32. pack_gqa
  33. )
  34. return out, q, k, v, out_padded, softmax_lse
  35. def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal,
  36. q_descale=None, k_descale=None, v_descale=None,
  37. window_size=(-1, -1), softcap=0.0,
  38. num_splits=1,
  39. pack_gqa=None):
  40. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  41. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  42. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.fwd_varlen(
  43. q,
  44. k,
  45. v,
  46. None,
  47. cu_seqlens_q, cu_seqlens_k, None, None, max_seqlen_q, max_seqlen_k,
  48. softmax_scale,
  49. causal,
  50. q_descale, k_descale, v_descale,
  51. window_size[0], window_size[1],
  52. softcap,
  53. num_splits,
  54. pack_gqa
  55. )
  56. # breakpoint()
  57. return out, q, k, v, out_padded, softmax_lse
  58. def _flash_attn_backward(
  59. dout,
  60. q,
  61. k,
  62. v,
  63. out,
  64. softmax_lse,
  65. dq,
  66. dk,
  67. dv,
  68. softmax_scale,
  69. causal,
  70. window_size=(-1, -1),
  71. sink_token_length=0,
  72. softcap=0.0,
  73. deterministic=False
  74. ):
  75. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  76. # dq, dk, dv are allocated by us so they should already be contiguous
  77. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  78. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
  79. dout,
  80. q,
  81. k,
  82. v,
  83. out,
  84. softmax_lse,
  85. dq,
  86. dk,
  87. dv,
  88. softmax_scale,
  89. causal,
  90. window_size[0],
  91. window_size[1],
  92. sink_token_length,
  93. softcap,
  94. deterministic,
  95. )
  96. return dq, dk, dv, softmax_d
  97. def _flash_attn_varlen_backward(
  98. dout,
  99. q,
  100. k,
  101. v,
  102. out,
  103. softmax_lse,
  104. cu_seqlens_q,
  105. cu_seqlens_k,
  106. max_seqlen_q,
  107. max_seqlen_k,
  108. dq,
  109. dk,
  110. dv,
  111. softmax_scale,
  112. causal,
  113. window_size=(-1, -1),
  114. softcap=0.0,
  115. deterministic=False
  116. ):
  117. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  118. # dq, dk, dv are allocated by us so they should already be contiguous
  119. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  120. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd_varlen(
  121. dout,
  122. q,
  123. k,
  124. v,
  125. out,
  126. softmax_lse,
  127. dq,
  128. dk,
  129. dv,
  130. cu_seqlens_q,
  131. cu_seqlens_k,
  132. None, None,
  133. max_seqlen_q,
  134. max_seqlen_k,
  135. softmax_scale,
  136. causal,
  137. window_size[0],
  138. window_size[1],
  139. softcap,
  140. deterministic,
  141. )
  142. return dq, dk, dv, softmax_d
  143. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  144. @staticmethod
  145. def forward(
  146. ctx,
  147. qkv,
  148. softmax_scale,
  149. causal,
  150. q_descale=None, k_descale=None, v_descale=None,
  151. window_size=(-1, -1),
  152. sink_token_length=0,
  153. softcap=0.0,
  154. deterministic=False,
  155. num_heads_q=None,
  156. ):
  157. if softmax_scale is None:
  158. softmax_scale = qkv.shape[-1] ** (-0.5)
  159. if qkv.dim() == 5:
  160. assert qkv.shape[-3] == 3
  161. q, k, v = qkv.unbind(dim=-3)
  162. else:
  163. assert qkv.dim() == 4
  164. assert num_heads_q is not None
  165. num_heads_k = (qkv.shape[2] - num_heads_q) // 2
  166. assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
  167. q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  168. out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  169. q,
  170. k,
  171. v,
  172. softmax_scale,
  173. causal=causal,
  174. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  175. window_size=window_size, sink_token_length=sink_token_length,
  176. softcap=softcap,
  177. )
  178. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  179. ctx.softmax_scale = softmax_scale
  180. ctx.causal = causal
  181. ctx.window_size = window_size
  182. ctx.sink_token_length = sink_token_length
  183. ctx.softcap = softcap
  184. ctx.deterministic = deterministic
  185. ctx.ndim = qkv.dim()
  186. # return out, softmax_lse
  187. return out
  188. @staticmethod
  189. def backward(ctx, dout, *args):
  190. q, k, v, out, softmax_lse = ctx.saved_tensors
  191. if ctx.ndim == 5:
  192. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  193. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  194. dq, dk, dv = dqkv.unbind(dim=-3)
  195. else:
  196. num_heads_q = q.shape[2]
  197. num_heads_k = k.shape[2]
  198. qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
  199. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  200. dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  201. _flash_attn_backward(
  202. dout,
  203. q,
  204. k,
  205. v,
  206. out,
  207. softmax_lse,
  208. dq,
  209. dk,
  210. dv,
  211. ctx.softmax_scale,
  212. ctx.causal,
  213. ctx.window_size,
  214. ctx.sink_token_length,
  215. ctx.softcap,
  216. ctx.deterministic,
  217. )
  218. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  219. return dqkv, None, None, None, None, None, None, None, None, None, None
  220. class FlashAttnFunc(torch.autograd.Function):
  221. @staticmethod
  222. def forward(
  223. ctx,
  224. q,
  225. k,
  226. v,
  227. softmax_scale,
  228. causal,
  229. q_descale=None, k_descale=None, v_descale=None,
  230. window_size=(-1, -1),
  231. sink_token_length=0,
  232. softcap=0.0,
  233. num_splits=1,
  234. pack_gqa=None,
  235. deterministic=False,
  236. ):
  237. if softmax_scale is None:
  238. softmax_scale = q.shape[-1] ** (-0.5)
  239. out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  240. q,
  241. k,
  242. v,
  243. softmax_scale,
  244. causal=causal,
  245. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  246. window_size=window_size,
  247. sink_token_length=sink_token_length,
  248. softcap=softcap,
  249. num_splits=num_splits,
  250. pack_gqa=pack_gqa,
  251. )
  252. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  253. ctx.softmax_scale = softmax_scale
  254. ctx.causal = causal
  255. ctx.window_size = window_size
  256. ctx.sink_token_length = sink_token_length
  257. ctx.softcap = softcap
  258. ctx.deterministic = deterministic
  259. return out, softmax_lse
  260. @staticmethod
  261. def backward(ctx, dout, *args):
  262. q, k, v, out, softmax_lse = ctx.saved_tensors
  263. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  264. _flash_attn_backward(
  265. dout,
  266. q,
  267. k,
  268. v,
  269. out,
  270. softmax_lse,
  271. dq,
  272. dk,
  273. dv,
  274. ctx.softmax_scale,
  275. ctx.causal,
  276. ctx.window_size,
  277. ctx.sink_token_length,
  278. ctx.softcap,
  279. ctx.deterministic,
  280. )
  281. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  282. dk = dk[..., : dout.shape[-1]]
  283. dv = dv[..., : dout.shape[-1]]
  284. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
  285. class FlashAttnVarlenFunc(torch.autograd.Function):
  286. @staticmethod
  287. def forward(
  288. ctx,
  289. q,
  290. k,
  291. v,
  292. cu_seqlens_q,
  293. cu_seqlens_k,
  294. max_seqlen_q,
  295. max_seqlen_k,
  296. softmax_scale,
  297. causal,
  298. q_descale=None, k_descale=None, v_descale=None,
  299. window_size=(-1, -1),
  300. softcap=0.0,
  301. num_splits=1,
  302. pack_gqa=None,
  303. deterministic=False,
  304. ):
  305. if softmax_scale is None:
  306. softmax_scale = q.shape[-1] ** (-0.5)
  307. out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  308. q,
  309. k,
  310. v,
  311. cu_seqlens_q,
  312. cu_seqlens_k,
  313. max_seqlen_q,
  314. max_seqlen_k,
  315. softmax_scale,
  316. causal=causal,
  317. q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
  318. window_size=window_size,
  319. softcap=softcap,
  320. num_splits=num_splits,
  321. pack_gqa=pack_gqa,
  322. )
  323. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k)
  324. ctx.max_seqlen_q = max_seqlen_q
  325. ctx.max_seqlen_k = max_seqlen_k
  326. ctx.softmax_scale = softmax_scale
  327. ctx.causal = causal
  328. ctx.window_size = window_size
  329. ctx.softcap = softcap
  330. ctx.deterministic = deterministic
  331. return out, softmax_lse
  332. @staticmethod
  333. def backward(ctx, dout, *args):
  334. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
  335. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  336. _flash_attn_varlen_backward(
  337. dout,
  338. q,
  339. k,
  340. v,
  341. out,
  342. softmax_lse,
  343. cu_seqlens_q,
  344. cu_seqlens_k,
  345. ctx.max_seqlen_q,
  346. ctx.max_seqlen_k,
  347. dq,
  348. dk,
  349. dv,
  350. ctx.softmax_scale,
  351. ctx.causal,
  352. ctx.window_size,
  353. ctx.softcap,
  354. ctx.deterministic,
  355. )
  356. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  357. dk = dk[..., : dout.shape[-1]]
  358. dv = dv[..., : dout.shape[-1]]
  359. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
  360. def flash_attn_qkvpacked_func(
  361. qkv,
  362. softmax_scale=None,
  363. causal=False,
  364. q_descale=None, k_descale=None, v_descale=None,
  365. window_size=(-1, -1),
  366. sink_token_length=0,
  367. softcap=0.0,
  368. deterministic=False,
  369. num_heads_q=None,
  370. ):
  371. """dropout_p should be set to 0.0 during evaluation
  372. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  373. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  374. of the gradients of Q, K, V.
  375. For multi-query and grouped-query attention (MQA/GQA), please see
  376. flash_attn_kvpacked_func and flash_attn_func.
  377. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  378. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  379. Arguments:
  380. qkv: (batch_size, seqlen, 3, nheads, headdim)
  381. dropout_p: float. Dropout probability.
  382. softmax_scale: float. The scaling of QK^T before applying softmax.
  383. Default to 1 / sqrt(headdim).
  384. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  385. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  386. softcap: float. Anything > 0 activates softcapping attention.
  387. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  388. the attention score of query i and key j.
  389. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  390. which is slightly slower and uses more memory. The forward pass is always deterministic.
  391. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  392. testing only. The returned probabilities are not guaranteed to be correct
  393. (they might not have the right scaling).
  394. Return:
  395. out: (batch_size, seqlen, nheads, headdim).
  396. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  397. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  398. normalization factor).
  399. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  400. The output of softmax (possibly with different scaling). It also encodes the dropout
  401. pattern (negative means that location was dropped, nonnegative means it was kept).
  402. """
  403. return FlashAttnQKVPackedFunc.apply(
  404. qkv,
  405. softmax_scale,
  406. causal,
  407. q_descale, k_descale, v_descale,
  408. window_size,
  409. sink_token_length,
  410. softcap,
  411. deterministic,
  412. num_heads_q,
  413. )
  414. def flash_attn_func(
  415. q,
  416. k,
  417. v,
  418. softmax_scale=None,
  419. causal=False,
  420. q_descale=None, k_descale=None, v_descale=None,
  421. window_size=(-1, -1),
  422. sink_token_length=0,
  423. softcap=0.0,
  424. num_splits=1,
  425. pack_gqa=None,
  426. deterministic=False
  427. ):
  428. """dropout_p should be set to 0.0 during evaluation
  429. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  430. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  431. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  432. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  433. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  434. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  435. 1 1 1 1 0
  436. 1 1 1 1 1
  437. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  438. 0 0
  439. 0 0
  440. 0 0
  441. 1 0
  442. 1 1
  443. If the row of the mask is all zero, the output will be zero.
  444. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  445. will only attend to keys between
  446. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  447. Arguments:
  448. q: (batch_size, seqlen, nheads, headdim)
  449. k: (batch_size, seqlen, nheads_k, headdim)
  450. v: (batch_size, seqlen, nheads_k, headdim)
  451. dropout_p: float. Dropout probability.
  452. softmax_scale: float. The scaling of QK^T before applying softmax.
  453. Default to 1 / sqrt(headdim).
  454. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  455. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  456. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  457. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  458. is added to the attention score of query i and key j.
  459. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  460. which is slightly slower and uses more memory. The forward pass is always deterministic.
  461. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  462. testing only. The returned probabilities are not guaranteed to be correct
  463. (they might not have the right scaling).
  464. Return:
  465. out: (batch_size, seqlen, nheads, headdim).
  466. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  467. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  468. normalization factor).
  469. """
  470. return FlashAttnFunc.apply(
  471. q,
  472. k,
  473. v,
  474. softmax_scale,
  475. causal,
  476. q_descale, k_descale, v_descale,
  477. window_size,
  478. sink_token_length,
  479. softcap,
  480. num_splits,
  481. pack_gqa,
  482. deterministic,
  483. )
  484. def flash_attn_varlen_func(
  485. q,
  486. k,
  487. v,
  488. cu_seqlens_q,
  489. cu_seqlens_k,
  490. max_seqlen_q,
  491. max_seqlen_k,
  492. softmax_scale=None,
  493. causal=False,
  494. q_descale=None, k_descale=None, v_descale=None,
  495. window_size=(-1, -1),
  496. softcap=0.0,
  497. num_splits=1,
  498. pack_gqa=None,
  499. deterministic=False
  500. ):
  501. return FlashAttnVarlenFunc.apply(
  502. q,
  503. k,
  504. v,
  505. cu_seqlens_q,
  506. cu_seqlens_k,
  507. max_seqlen_q,
  508. max_seqlen_k,
  509. softmax_scale,
  510. causal,
  511. q_descale, k_descale, v_descale,
  512. window_size,
  513. softcap,
  514. num_splits,
  515. pack_gqa,
  516. deterministic,
  517. )
  518. def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
  519. return flashattn_hopper_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
  520. def flash_attn_with_kvcache(
  521. q,
  522. k_cache,
  523. v_cache,
  524. k=None,
  525. v=None,
  526. rotary_cos=None,
  527. rotary_sin=None,
  528. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  529. cache_batch_idx: Optional[torch.Tensor] = None,
  530. cache_leftpad: Optional[torch.Tensor] = None,
  531. page_table: Optional[torch.Tensor] = None,
  532. cu_seqlens_q: Optional[torch.Tensor] = None,
  533. max_seqlen_q: Optional[int] = None,
  534. softmax_scale=None,
  535. causal=False,
  536. window_size=(-1, -1), # -1 means infinite context window
  537. sink_token_length=0,
  538. softcap=0.0, # 0.0 means deactivated
  539. rotary_interleaved=True,
  540. num_splits=0, # Can be tuned for speed
  541. pack_gqa=None, # Can be tuned for speed
  542. return_softmax_lse=False,
  543. ):
  544. """
  545. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  546. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  547. the previous step, and update them with the new keys/values from the current step, and do
  548. attention with the updated cache, all in 1 kernel.
  549. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  550. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  551. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  552. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  553. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  554. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  555. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  556. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  557. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  558. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  559. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  560. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  561. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  562. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  563. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  564. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  565. 1 1 1 1 0
  566. 1 1 1 1 1
  567. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  568. 0 0
  569. 0 0
  570. 0 0
  571. 1 0
  572. 1 1
  573. If the row of the mask is all zero, the output will be zero.
  574. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  575. will only attend to keys between
  576. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  577. Note: Does not support backward pass.
  578. Arguments:
  579. q: (batch_size, seqlen, nheads, headdim)
  580. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
  581. or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
  582. page_block_size must be a multiple of 256.
  583. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no _table,
  584. or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
  585. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  586. k with k_cache, starting at the indices specified by cache_seqlens.
  587. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  588. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  589. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  590. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  591. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  592. KV cache.
  593. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  594. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  595. If the indices are not distinct, and k and v are provided, the values updated in the cache
  596. might come from any of the duplicate indices.
  597. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  598. page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  599. softmax_scale: float. The scaling of QK^T before applying softmax.
  600. Default to 1 / sqrt(headdim).
  601. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  602. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  603. softcap: float. Anything > 0 activates softcapping attention.
  604. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  605. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  606. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  607. (i.e. GPT-NeoX style).
  608. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  609. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  610. to automatically determine the number of splits.
  611. Don't change this unless you know what you are doing.
  612. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  613. Return:
  614. out: (batch_size, seqlen, nheads, headdim).
  615. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  616. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  617. normalization factor).
  618. """
  619. assert sink_token_length == 0
  620. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  621. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  622. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  623. if softmax_scale is None:
  624. softmax_scale = q.shape[-1] ** (-0.5)
  625. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  626. cache_seqlens = torch.full(
  627. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  628. )
  629. cache_seqlens = maybe_contiguous(cache_seqlens)
  630. cache_batch_idx = maybe_contiguous(cache_batch_idx)
  631. page_table = maybe_contiguous(page_table)
  632. cu_seqlens_q = maybe_contiguous(cu_seqlens_q)
  633. out, softmax_lse, *rest = flashattn_hopper_cuda.fwd_kvcache(
  634. q,
  635. k_cache,
  636. v_cache,
  637. k,
  638. v,
  639. None, # out
  640. cache_seqlens,
  641. rotary_cos,
  642. rotary_sin,
  643. cache_batch_idx,
  644. cache_leftpad,
  645. page_table,
  646. cu_seqlens_q,
  647. max_seqlen_q,
  648. softmax_scale,
  649. causal,
  650. None, None, None, # qkv_descale
  651. window_size[0],
  652. window_size[1],
  653. sink_token_length,
  654. softcap,
  655. rotary_interleaved,
  656. num_splits,
  657. pack_gqa
  658. )
  659. # return (out, softmax_lse) if return_softmax_lse else out
  660. return (out, softmax_lse, *rest) if return_softmax_lse else out