# flash_attn_interface.py (23 KB)
  # Copyright (c) 2023, Tri Dao.
  from typing import Optional, Union

  import torch
  import torch.nn as nn

  # isort: off
  # We need to import the CUDA kernels after importing torch
  import flashattn_hopper_cuda

  # isort: on
  def maybe_contiguous(x):
      """Return ``x`` with a unit-stride last dimension, copying only when required.

      ``None`` is passed through untouched, and so is any object whose
      ``stride(-1)`` is already 1. Otherwise ``x.contiguous()`` is returned.
      """
      if x is None or x.stride(-1) == 1:
          return x
      return x.contiguous()
  def _flash_attn_forward(
      q,
      k,
      v,
      softmax_scale,
      causal,
      window_size,
      descale_q=None,
      descale_k=None,
      descale_v=None,
      gqa_parallel=False,
  ):
      """Call the dense (batched) forward kernel ``flashattn_hopper_cuda.fwd``.

      ``q``, ``k`` and ``v`` are passed through ``maybe_contiguous`` first; every
      other argument is forwarded unchanged, with ``window_size`` split into its
      first two elements.

      Returns:
          The 7-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask)`` exactly
          as produced by the kernel.
          NOTE(review): the returned q/k/v are presumably the (possibly head-dim
          padded) tensors the kernel actually consumed -- confirm against the
          extension.
      """
      q = maybe_contiguous(q)
      k = maybe_contiguous(k)
      v = maybe_contiguous(v)
      window_left = window_size[0]
      window_right = window_size[1]
      (
          out,
          q_used,
          k_used,
          v_used,
          out_padded,
          softmax_lse,
          S_dmask,
      ) = flashattn_hopper_cuda.fwd(
          q,
          k,
          v,
          None,  # presumably an optional preallocated output buffer -- TODO confirm
          softmax_scale,
          descale_q,
          descale_k,
          descale_v,
          causal,
          window_left,
          window_right,
          gqa_parallel,
      )
      return out, q_used, k_used, v_used, out_padded, softmax_lse, S_dmask
  def _flash_attn_backward(
      dout,
      q,
      k,
      v,
      out,
      softmax_lse,
      dq,
      dk,
      dv,
      softmax_scale,
      causal,
      window_size,
      deterministic=False,
  ):
      """Call the dense (batched) backward kernel ``flashattn_hopper_cuda.bwd``.

      ``dout``, ``q``, ``k``, ``v`` and ``out`` are passed through
      ``maybe_contiguous``. ``dq``, ``dk`` and ``dv`` are gradient buffers supplied
      by the caller and are handed to the kernel as-is.

      Returns:
          ``(dq, dk, dv, softmax_d)`` -- the first four values returned by the
          kernel; anything after them is discarded.
      """
      # Only the inputs are normalised here: dq, dk, dv are allocated by us, so
      # they should already be contiguous.
      dout = maybe_contiguous(dout)
      q = maybe_contiguous(q)
      k = maybe_contiguous(k)
      v = maybe_contiguous(v)
      out = maybe_contiguous(out)
      window_left = window_size[0]
      window_right = window_size[1]
      grad_q, grad_k, grad_v, softmax_d, *_extra = flashattn_hopper_cuda.bwd(
          dout,
          q,
          k,
          v,
          out,
          softmax_lse,
          dq,
          dk,
          dv,
          softmax_scale,
          causal,
          window_left,
          window_right,
          deterministic,
      )
      return grad_q, grad_k, grad_v, softmax_d
  def _flash_attn_varlen_forward(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale,
      causal,
      block_table=None,
      window_size=(-1, -1),
      seqused_q=None,
      seqused_k=None,
  ):
      """Call the variable-length forward kernel ``flashattn_hopper_cuda.varlen_fwd``.

      ``q``, ``k`` and ``v`` are passed through the module-level
      ``maybe_contiguous`` helper (the same one the dense path uses); all other
      arguments are forwarded unchanged, with ``window_size`` split into its first
      two elements.

      Returns:
          The 6-tuple ``(out, q, k, v, out_padded, softmax_lse)`` exactly as
          produced by the kernel.

      Raises:
          AssertionError: if ``block_table`` is given while ``k`` has dtype
              ``torch.float8_e4m3fn``.
      """
      assert (
          block_table is None or k.dtype != torch.float8_e4m3fn
      ), "Paged Attention / block_table is not supported for fp8 just yet"
      q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
      out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
          q,
          k,
          v,
          None,  # presumably an optional preallocated output buffer -- TODO confirm
          cu_seqlens_q,
          cu_seqlens_k,
          seqused_q,
          seqused_k,
          block_table,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size[0],
          window_size[1],
      )
      return out, q, k, v, out_padded, softmax_lse
  def _flash_attn_varlen_backward(
      dout,
      q,
      k,
      v,
      out,
      softmax_lse,
      dq,
      dk,
      dv,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale,
      causal,
      window_size,
      deterministic=False,
      seqused_q=None,
      seqused_k=None,
  ):
      """Call the variable-length backward kernel ``flashattn_hopper_cuda.varlen_bwd``.

      ``dout``, ``q``, ``k``, ``v`` and ``out`` are passed through the module-level
      ``maybe_contiguous`` helper (the same one the dense path uses). ``dq``,
      ``dk`` and ``dv`` are gradient buffers supplied by the caller and are handed
      to the kernel as-is.

      Returns:
          ``(dq, dk, dv, softmax_d)`` -- the first four values returned by the
          kernel; anything after them is discarded.
      """
      # dq, dk, dv are allocated by us so they should already be contiguous
      dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
      dq, dk, dv, softmax_d, *_ = flashattn_hopper_cuda.varlen_bwd(
          dout,
          q,
          k,
          v,
          out,
          softmax_lse,
          dq,
          dk,
          dv,
          cu_seqlens_q,
          cu_seqlens_k,
          seqused_q,
          seqused_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size[0],
          window_size[1],
          deterministic,
      )
      return dq, dk, dv, softmax_d
  class FlashAttnFunc(torch.autograd.Function):
      """Autograd wrapper pairing ``_flash_attn_forward`` with ``_flash_attn_backward``."""

      @staticmethod
      def forward(
          ctx,
          q,
          k,
          v,
          softmax_scale,
          causal,
          window_size,
          deterministic=False,
          descale_q=None,
          descale_k=None,
          descale_v=None,
          gqa_parallel=False,
      ):
          """Run the forward kernel and stash what ``backward`` needs on ``ctx``.

          Returns ``(out, softmax_lse)``. When ``softmax_scale`` is None it
          defaults to ``q.shape[-1] ** -0.5``, computed from ``q`` as passed in.
          """
          if softmax_scale is None:
              softmax_scale = q.shape[-1] ** (-0.5)
          fwd_result = _flash_attn_forward(
              q,
              k,
              v,
              softmax_scale,
              causal,
              window_size,
              descale_q=descale_q,
              descale_k=descale_k,
              descale_v=descale_v,
              gqa_parallel=gqa_parallel,
          )
          # The q/k/v saved for backward are the ones handed back by the kernel
          # wrapper, not the caller's originals; S_dmask is not used.
          out, q, k, v, out_padded, softmax_lse, _s_dmask = fwd_result
          ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
          ctx.softmax_scale = softmax_scale
          ctx.causal = causal
          ctx.window_size = window_size
          ctx.deterministic = deterministic
          ctx.gqa_parallel = gqa_parallel
          return out, softmax_lse

      @staticmethod
      def backward(ctx, dout, *args):
          """Return gradients for ``q``, ``k``, ``v``; all other inputs get ``None``."""
          q, k, v, out, softmax_lse = ctx.saved_tensors
          dq = torch.empty_like(q)
          dk = torch.empty_like(k)
          dv = torch.empty_like(v)
          # The return value is ignored: the gradients are read back from the
          # dq/dk/dv buffers that were passed in.
          _flash_attn_backward(
              dout,
              q,
              k,
              v,
              out,
              softmax_lse,
              dq,
              dk,
              dv,
              ctx.softmax_scale,
              ctx.causal,
              ctx.window_size,
              ctx.deterministic,
          )
          # We could have padded the head dimension, so trim back to dout's width.
          head_dim = dout.shape[-1]
          grads = (dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim])
          # One None per remaining forward input (softmax_scale .. gqa_parallel).
          return grads + (None,) * 8
  class FlashAttnVarlenFunc(torch.autograd.Function):
      """Autograd wrapper pairing the variable-length forward and backward kernels."""

      @staticmethod
      def forward(
          ctx,
          q,
          k,
          v,
          cu_seqlens_q,
          cu_seqlens_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size,
          deterministic=False,
          seqused_q=None,
          seqused_k=None,
          block_table=None,
      ):
          """Run the varlen forward kernel and stash what ``backward`` needs on ``ctx``.

          Returns ``(out, softmax_lse)``. When ``softmax_scale`` is None it
          defaults to ``q.shape[-1] ** -0.5``, computed from ``q`` as passed in.
          """
          if softmax_scale is None:
              softmax_scale = q.shape[-1] ** (-0.5)
          fwd_result = _flash_attn_varlen_forward(
              q,
              k,
              v,
              cu_seqlens_q,
              cu_seqlens_k,
              max_seqlen_q,
              max_seqlen_k,
              softmax_scale,
              causal=causal,
              window_size=window_size,
              seqused_q=seqused_q,
              seqused_k=seqused_k,
              block_table=block_table,
          )
          # The q/k/v saved for backward are the ones handed back by the kernel
          # wrapper, not the caller's originals.
          out, q, k, v, out_padded, softmax_lse = fwd_result
          ctx.save_for_backward(
              q,
              k,
              v,
              out_padded,
              softmax_lse,
              cu_seqlens_q,
              cu_seqlens_k,
              seqused_q,
              seqused_k,
          )
          ctx.softmax_scale = softmax_scale
          ctx.causal = causal
          ctx.window_size = window_size
          ctx.deterministic = deterministic
          ctx.max_seqlen_q = max_seqlen_q
          ctx.max_seqlen_k = max_seqlen_k
          return out, softmax_lse

      @staticmethod
      def backward(ctx, dout, *args):
          """Return gradients for ``q``, ``k``, ``v``; all other inputs get ``None``."""
          (
              q,
              k,
              v,
              out,
              softmax_lse,
              cu_seqlens_q,
              cu_seqlens_k,
              seqused_q,
              seqused_k,
          ) = ctx.saved_tensors
          dq = torch.empty_like(q)
          dk = torch.empty_like(k)
          dv = torch.empty_like(v)
          # The return value is ignored: the gradients are read back from the
          # dq/dk/dv buffers that were passed in.
          _flash_attn_varlen_backward(
              dout,
              q,
              k,
              v,
              out,
              softmax_lse,
              dq,
              dk,
              dv,
              cu_seqlens_q,
              cu_seqlens_k,
              ctx.max_seqlen_q,
              ctx.max_seqlen_k,
              ctx.softmax_scale,
              ctx.causal,
              ctx.window_size,
              ctx.deterministic,
              seqused_q,
              seqused_k,
          )
          # We could have padded the head dimension, so trim back to dout's width.
          head_dim = dout.shape[-1]
          grads = (dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim])
          # One None per remaining forward input (cu_seqlens_q .. block_table).
          return grads + (None,) * 11
  def flash_attn_func(
      q,
      k,
      v,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),
      deterministic=False,
      descale_q=None,
      descale_k=None,
      descale_v=None,
      gqa_parallel=False,
  ):
      """Attention over dense (batched) q, k, v; differentiable via ``FlashAttnFunc``.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV
      with fewer heads than Q. Note that the number of heads in Q must be divisible
      by the number of heads in KV. For example, if Q has 6 heads and K, V have 2
      heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of
      Q will attend to head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the
      attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
      mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If the row of the mask is all zero, the output will be zero.

      If window_size != (-1, -1), implements sliding window local attention. Query
      at position i will only attend to keys between
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      inclusive.

      Arguments:
          q: (batch_size, seqlen, nheads, headdim)
          k: (batch_size, seqlen, nheads_k, headdim)
          v: (batch_size, seqlen, nheads_k, headdim)
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Default to 1 / sqrt(headdim).
          causal: bool. Whether to apply causal attention mask (e.g., for
              auto-regressive modeling).
          window_size: (left, right). If not (-1, -1), implements sliding window
              local attention.
          deterministic: bool. Whether to use the deterministic implementation of
              the backward pass, which is slightly slower and uses more memory. The
              forward pass is always deterministic.
          descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
          descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
          descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
          gqa_parallel: bool. Forwarded unchanged to the forward kernel.
              NOTE(review): presumably selects a GQA-specific parallelisation
              strategy -- confirm against the CUDA extension.
      Return:
          A tuple ``(out, softmax_lse)``; both values are always returned.
          out: (batch_size, seqlen, nheads, headdim).
          softmax_lse: (batch_size, nheads, seqlen). The logsumexp of each row of
              the matrix QK^T * scaling (e.g., log of the softmax normalization
              factor).
      """
      # autograd.Function.apply takes positional arguments only; the order below
      # must match FlashAttnFunc.forward.
      return FlashAttnFunc.apply(
          q, k, v,
          softmax_scale, causal, window_size, deterministic,
          descale_q, descale_k, descale_v,
          gqa_parallel,
      )
  def flash_attn_varlen_func(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),
      deterministic=False,
      seqused_q=None,
      seqused_k=None,
      block_table=None,
  ):
      """Attention over packed variable-length sequences; differentiable via ``FlashAttnVarlenFunc``.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V
      with fewer heads than Q. Note that the number of heads in Q must be divisible
      by the number of heads in KV. For example, if Q has 6 heads and K, V have 2
      heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of
      Q will attend to head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the
      attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
      mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If the row of the mask is all zero, the output will be zero.

      Arguments:
          q: (total_q, nheads, headdim), where total_q = total number of query
              tokens in the batch.
          k: (total_k, nheads_k, headdim), where total_k = total number of key
              tokens in the batch.
          v: (total_k, nheads_k, headdim), where total_k = total number of key
              tokens in the batch.
          cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative
              sequence lengths of the sequences in the batch, used to index into q.
          cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative
              sequence lengths of the sequences in the batch, used to index into kv.
          max_seqlen_q: int. Maximum query sequence length in the batch.
          max_seqlen_k: int. Maximum key sequence length in the batch.
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Default to 1 / sqrt(headdim).
          causal: bool. Whether to apply causal attention mask (e.g., for
              auto-regressive modeling).
          window_size: (left, right). If not (-1, -1), implements sliding window
              local attention.
          deterministic: bool. Forwarded to the backward kernel.
              NOTE(review): presumably selects the deterministic backward
              implementation, as in flash_attn_func -- confirm.
          seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the
              actual number of query and output tokens in each sequence.
          seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the
              actual number of key and value tokens in each sequence.
          block_table: Optional block_table of dtype int32 and shape
              [batch_size, num_blocks_per_seq] used for paged attention.
      Return:
          A tuple ``(out, softmax_lse)``; both values are always returned.
          out: (total, nheads, headdim).
          softmax_lse: (nheads, total_q_seqlen). The logsumexp of each row of the
              matrix QK^T * scaling (e.g., log of the softmax normalization factor).
      """
      # autograd.Function.apply takes positional arguments only; the order below
      # must match FlashAttnVarlenFunc.forward.
      return FlashAttnVarlenFunc.apply(
          q, k, v,
          cu_seqlens_q, cu_seqlens_k,
          max_seqlen_q, max_seqlen_k,
          softmax_scale, causal, window_size, deterministic,
          seqused_q, seqused_k,
          block_table,
      )
  def flash_attn_with_kvcache(
      q,
      k_cache,
      v_cache,
      # Not exposed yet: k, v, rotary_cos, rotary_sin
      cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
      cache_batch_idx: Optional[torch.Tensor] = None,
      # Not exposed yet: cache_leftpad, block_table
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),  # -1 means infinite context window
      # Not exposed yet: softcap, rotary_interleaved, alibi_slopes
      num_splits=0,
      return_softmax_lse=False,
      gqa_parallel=None,
      max_seqlen_k_hint=None,
      descale_q=None,
      descale_k=None,
      descale_v=None,
  ):
      """Forward-only attention of ``q`` against an existing KV cache.

      NOTE: The KV cache API for FlashAttention-3 is a work in progress. The
      FlashAttention-2 function of the same name can also append new k/v to the
      cache in place, apply rotary embeddings, use a paged cache (block_table), a
      left-padded cache, softcapping and ALiBi. None of those are exposed here:
      this function passes k, v, rotary_cos, rotary_sin, cache_leftpad, block_table
      and alibi_slopes to the kernel as None, softcap as 0.0 and rotary_interleaved
      as True.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV
      with fewer heads than Q. Note that the number of heads in Q must be divisible
      by the number of heads in KV. For example, if Q has 6 heads and K, V have 2
      heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of
      Q will attend to head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the
      attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
      mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If the row of the mask is all zero, the output will be zero.

      If window_size != (-1, -1), implements sliding window local attention. Query
      at position i will only attend to keys between
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      inclusive.

      Note: Does not support backward pass.

      Arguments:
          q: (batch_size, seqlen, nheads, headdim)
          k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim). The last
              dimension must have stride 1.
          v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim). The last
              dimension must have stride 1.
          cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence
              lengths of the KV cache. An int is expanded to an int32 tensor of
              shape (k_cache.shape[0],) on k_cache's device.
          cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to
              index into the KV cache. If None, we assume that the batch indices
              are [0, 1, 2, ..., batch_size - 1].
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Default to 1 / sqrt(headdim).
          causal: bool. Whether to apply causal attention mask (e.g., for
              auto-regressive modeling).
          window_size: (left, right). If not (-1, -1), implements sliding window
              local attention.
          num_splits: int. If > 1, split the key/value into this many chunks along
              the sequence. If num_splits == 1, we don't split the key/value. If
              num_splits == 0, we use a heuristic to automatically determine the
              number of splits. Don't change this unless you know what you are doing.
          return_softmax_lse: bool. Whether to return the logsumexp of the
              attention scores.
          gqa_parallel: bool or None. If None, it is set to ``q.shape[1] <= 64``.
              Whatever its value, it is forced to False when q and k_cache have the
              same number of heads (i.e. not a GQA/MQA setup).
          max_seqlen_k_hint: int or None. Defaults to ``k_cache.shape[1]``.
              Forwarded to the kernel.
              NOTE(review): presumably an upper bound on the key length used by the
              split heuristic -- confirm against the CUDA extension.
          descale_q, descale_k, descale_v: forwarded unchanged to the kernel.
              NOTE(review): presumably (1,), fp32 de-quantization scaling factors
              for fp8 execution, as in flash_attn_func -- confirm.
      Return:
          out: (batch_size, seqlen, nheads, headdim).
          softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads,
              seqlen). The logsumexp of each row of the matrix QK^T * scaling
              (e.g., log of the softmax normalization factor).
      """
      # Features this entry point does not expose yet. The kernel still takes them
      # positionally, so they are pinned to fixed values here.
      k = v = None
      rotary_cos = rotary_sin = None
      cache_leftpad = block_table = alibi_slopes = None
      softcap = 0.0
      rotary_interleaved = True

      assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
      assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
      q = maybe_contiguous(q)
      if softmax_scale is None:
          softmax_scale = q.shape[-1] ** (-0.5)

      # A plain int means "every cache row has this length".
      if isinstance(cache_seqlens, int):
          cache_seqlens = torch.full(
              (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
          )
      cache_seqlens = maybe_contiguous(cache_seqlens)
      cache_batch_idx = maybe_contiguous(cache_batch_idx)

      if gqa_parallel is None:
          gqa_parallel = q.shape[1] <= 64
      # not in gqa/mqa setup
      if q.shape[2] == k_cache.shape[2]:
          gqa_parallel = False
      if max_seqlen_k_hint is None:
          max_seqlen_k_hint = k_cache.shape[1]

      out, softmax_lse = flashattn_hopper_cuda.fwd_kvcache(
          q,
          k_cache,
          v_cache,
          k,
          v,
          cache_seqlens,
          rotary_cos,
          rotary_sin,
          cache_batch_idx,
          cache_leftpad,
          block_table,
          alibi_slopes,
          None,  # presumably an optional preallocated output buffer -- TODO confirm
          softmax_scale,
          descale_q,
          descale_k,
          descale_v,
          causal,
          window_size[0],
          window_size[1],
          softcap,
          rotary_interleaved,
          num_splits,
          max_seqlen_k_hint,
          gqa_parallel,
      )
      if return_softmax_lse:
          return out, softmax_lse
      return out