flash_attn_interface.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(
  12. q,
  13. k,
  14. v,
  15. softmax_scale,
  16. causal,
  17. window_size,
  18. descale_q = None,
  19. descale_k = None,
  20. descale_v = None,
  21. gqa_parallel=False
  22. ):
  23. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  24. out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
  25. q,
  26. k,
  27. v,
  28. None,
  29. softmax_scale,
  30. descale_q,
  31. descale_k,
  32. descale_v,
  33. causal,
  34. window_size[0],
  35. window_size[1],
  36. gqa_parallel
  37. )
  38. return out, q, k, v, out_padded, softmax_lse, S_dmask
  39. def _flash_attn_backward(
  40. dout,
  41. q,
  42. k,
  43. v,
  44. out,
  45. softmax_lse,
  46. dq,
  47. dk,
  48. dv,
  49. softmax_scale,
  50. causal,
  51. window_size,
  52. deterministic=False
  53. ):
  54. # dq, dk, dv are allocated by us so they should already be contiguous
  55. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  56. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
  57. dout,
  58. q,
  59. k,
  60. v,
  61. out,
  62. softmax_lse,
  63. dq,
  64. dk,
  65. dv,
  66. softmax_scale,
  67. causal,
  68. window_size[0],
  69. window_size[1],
  70. deterministic,
  71. )
  72. return dq, dk, dv, softmax_d
  73. def _flash_attn_varlen_forward(
  74. q,
  75. k,
  76. v,
  77. cu_seqlens_q,
  78. cu_seqlens_k,
  79. max_seqlen_q,
  80. max_seqlen_k,
  81. softmax_scale,
  82. causal,
  83. window_size=(-1, -1),
  84. seqused_q=None,
  85. seqused_k=None,
  86. descale_q=None,
  87. descale_k=None,
  88. descale_v=None,
  89. ):
  90. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  91. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  92. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
  93. q,
  94. k,
  95. v,
  96. None,
  97. cu_seqlens_q,
  98. cu_seqlens_k,
  99. seqused_q,
  100. seqused_k,
  101. max_seqlen_q,
  102. max_seqlen_k,
  103. softmax_scale,
  104. descale_q,
  105. descale_k,
  106. descale_v,
  107. causal,
  108. window_size[0],
  109. window_size[1],
  110. )
  111. # if out.isnan().any() or softmax_lse.isnan().any():
  112. # breakpoint()
  113. return out, q, k, v, out_padded, softmax_lse
  114. def _flash_attn_varlen_backward(
  115. dout,
  116. q,
  117. k,
  118. v,
  119. out,
  120. softmax_lse,
  121. dq,
  122. dk,
  123. dv,
  124. cu_seqlens_q,
  125. cu_seqlens_k,
  126. max_seqlen_q,
  127. max_seqlen_k,
  128. softmax_scale,
  129. causal,
  130. window_size,
  131. deterministic=False,
  132. seqused_q=None,
  133. seqused_k=None,
  134. ):
  135. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  136. # dq, dk, dv are allocated by us so they should already be contiguous
  137. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  138. (
  139. dq,
  140. dk,
  141. dv,
  142. softmax_d,
  143. *rest,
  144. ) = flashattn_hopper_cuda.varlen_bwd(
  145. dout,
  146. q,
  147. k,
  148. v,
  149. out,
  150. softmax_lse,
  151. dq,
  152. dk,
  153. dv,
  154. cu_seqlens_q,
  155. cu_seqlens_k,
  156. seqused_q,
  157. seqused_k,
  158. max_seqlen_q,
  159. max_seqlen_k,
  160. softmax_scale,
  161. causal,
  162. window_size[0],
  163. window_size[1],
  164. deterministic,
  165. )
  166. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  167. # breakpoint()
  168. return dq, dk, dv, softmax_d
  169. class FlashAttnFunc(torch.autograd.Function):
  170. @staticmethod
  171. def forward(
  172. ctx,
  173. q,
  174. k,
  175. v,
  176. softmax_scale,
  177. causal,
  178. window_size,
  179. deterministic=False,
  180. descale_q=None,
  181. descale_k=None,
  182. descale_v=None,
  183. gqa_parallel=False,
  184. ):
  185. if softmax_scale is None:
  186. softmax_scale = q.shape[-1] ** (-0.5)
  187. out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
  188. q,
  189. k,
  190. v,
  191. softmax_scale,
  192. causal,
  193. window_size,
  194. descale_q=descale_q,
  195. descale_k=descale_k,
  196. descale_v=descale_v,
  197. gqa_parallel=gqa_parallel,
  198. )
  199. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  200. ctx.softmax_scale = softmax_scale
  201. ctx.causal = causal
  202. ctx.window_size = window_size
  203. ctx.deterministic = deterministic
  204. ctx.gqa_parallel = gqa_parallel
  205. return out, softmax_lse
  206. @staticmethod
  207. def backward(ctx, dout, *args):
  208. q, k, v, out, softmax_lse = ctx.saved_tensors
  209. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  210. _flash_attn_backward(
  211. dout,
  212. q,
  213. k,
  214. v,
  215. out,
  216. softmax_lse,
  217. dq,
  218. dk,
  219. dv,
  220. ctx.softmax_scale,
  221. ctx.causal,
  222. ctx.window_size,
  223. ctx.deterministic,
  224. )
  225. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  226. dk = dk[..., : dout.shape[-1]]
  227. dv = dv[..., : dout.shape[-1]]
  228. return dq, dk, dv, None, None, None, None, None, None, None, None
  229. class FlashAttnVarlenFunc(torch.autograd.Function):
  230. @staticmethod
  231. def forward(
  232. ctx,
  233. q,
  234. k,
  235. v,
  236. cu_seqlens_q,
  237. cu_seqlens_k,
  238. max_seqlen_q,
  239. max_seqlen_k,
  240. softmax_scale,
  241. causal,
  242. window_size,
  243. deterministic=False,
  244. seqused_q=None,
  245. seqused_k=None,
  246. descale_q=None,
  247. descale_k=None,
  248. descale_v=None,
  249. ):
  250. if softmax_scale is None:
  251. softmax_scale = q.shape[-1] ** (-0.5)
  252. out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  253. q,
  254. k,
  255. v,
  256. cu_seqlens_q,
  257. cu_seqlens_k,
  258. max_seqlen_q,
  259. max_seqlen_k,
  260. softmax_scale,
  261. causal=causal,
  262. window_size=window_size,
  263. seqused_q=seqused_q,
  264. seqused_k=seqused_k,
  265. descale_q=descale_q,
  266. descale_k=descale_k,
  267. descale_v=descale_v,
  268. )
  269. ctx.save_for_backward(
  270. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
  271. seqused_q, seqused_k
  272. )
  273. ctx.max_seqlen_q = max_seqlen_q
  274. ctx.max_seqlen_k = max_seqlen_k
  275. ctx.softmax_scale = softmax_scale
  276. ctx.causal = causal
  277. ctx.window_size = window_size
  278. ctx.deterministic = deterministic
  279. return out, softmax_lse
  280. @staticmethod
  281. def backward(ctx, dout, *args):
  282. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
  283. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  284. _flash_attn_varlen_backward(
  285. dout,
  286. q,
  287. k,
  288. v,
  289. out,
  290. softmax_lse,
  291. dq,
  292. dk,
  293. dv,
  294. cu_seqlens_q,
  295. cu_seqlens_k,
  296. ctx.max_seqlen_q,
  297. ctx.max_seqlen_k,
  298. ctx.softmax_scale,
  299. ctx.causal,
  300. ctx.window_size,
  301. ctx.deterministic,
  302. seqused_q,
  303. seqused_k,
  304. )
  305. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  306. dk = dk[..., : dout.shape[-1]]
  307. dv = dv[..., : dout.shape[-1]]
  308. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
  309. def flash_attn_func(
  310. q,
  311. k,
  312. v,
  313. softmax_scale=None,
  314. causal=False,
  315. window_size=(-1, -1),
  316. deterministic=False,
  317. descale_q=None,
  318. descale_k=None,
  319. descale_v=None,
  320. gqa_parallel=False,
  321. ):
  322. """dropout_p should be set to 0.0 during evaluation
  323. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  324. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  325. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  326. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  327. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  328. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  329. 1 1 1 1 0
  330. 1 1 1 1 1
  331. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  332. 0 0
  333. 0 0
  334. 0 0
  335. 1 0
  336. 1 1
  337. If the row of the mask is all zero, the output will be zero.
  338. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  339. will only attend to keys between
  340. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  341. Arguments:
  342. q: (batch_size, seqlen, nheads, headdim)
  343. k: (batch_size, seqlen, nheads_k, headdim)
  344. v: (batch_size, seqlen, nheads_k, headdim)
  345. dropout_p: float. Dropout probability.
  346. softmax_scale: float. The scaling of QK^T before applying softmax.
  347. Default to 1 / sqrt(headdim).
  348. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  349. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  350. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  351. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  352. is added to the attention score of query i and key j.
  353. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  354. which is slightly slower and uses more memory. The forward pass is always deterministic.
  355. descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
  356. descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
  357. descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
  358. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  359. testing only. The returned probabilities are not guaranteed to be correct
  360. (they might not have the right scaling).
  361. Return:
  362. out: (batch_size, seqlen, nheads, headdim).
  363. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  364. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  365. normalization factor).
  366. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  367. The output of softmax (possibly with different scaling). It also encodes the dropout
  368. pattern (negative means that location was dropped, nonnegative means it was kept).
  369. """
  370. return FlashAttnFunc.apply(
  371. q,
  372. k,
  373. v,
  374. softmax_scale,
  375. causal,
  376. window_size,
  377. deterministic,
  378. descale_q,
  379. descale_k,
  380. descale_v,
  381. gqa_parallel
  382. )
  383. def flash_attn_varlen_func(
  384. q,
  385. k,
  386. v,
  387. cu_seqlens_q,
  388. cu_seqlens_k,
  389. max_seqlen_q,
  390. max_seqlen_k,
  391. softmax_scale=None,
  392. causal=False,
  393. window_size=(-1, -1),
  394. deterministic=False,
  395. seqused_q=None,
  396. seqused_k=None,
  397. descale_q=None,
  398. descale_k=None,
  399. descale_v=None,
  400. ):
  401. """
  402. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  403. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  404. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  405. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  406. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  407. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  408. 1 1 1 1 0
  409. 1 1 1 1 1
  410. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  411. 0 0
  412. 0 0
  413. 0 0
  414. 1 0
  415. 1 1
  416. If the row of the mask is all zero, the output will be zero.
  417. Arguments:
  418. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  419. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  420. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  421. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  422. of the sequences in the batch, used to index into q.
  423. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  424. of the sequences in the batch, used to index into kv.
  425. max_seqlen_q: int. Maximum query sequence length in the batch.
  426. max_seqlen_k: int. Maximum key sequence length in the batch.
  427. softmax_scale: float. The scaling of QK^T before applying softmax.
  428. Default to 1 / sqrt(headdim).
  429. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  430. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  431. seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
  432. query and output tokens in each sequence.
  433. seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
  434. key and value tokens in each sequence.
  435. Return:
  436. out: (total, nheads, headdim).
  437. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  438. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  439. normalization factor).
  440. """
  441. return FlashAttnVarlenFunc.apply(
  442. q,
  443. k,
  444. v,
  445. cu_seqlens_q,
  446. cu_seqlens_k,
  447. max_seqlen_q,
  448. max_seqlen_k,
  449. softmax_scale,
  450. causal,
  451. window_size,
  452. deterministic,
  453. seqused_q,
  454. seqused_k,
  455. descale_q,
  456. descale_k,
  457. descale_v,
  458. )
  459. def flash_attn_with_kvcache(
  460. q,
  461. k_cache,
  462. v_cache,
  463. # k=None,
  464. # v=None,
  465. # rotary_cos=None,
  466. # rotary_sin=None,
  467. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  468. cache_batch_idx: Optional[torch.Tensor] = None,
  469. # cache_leftpad: Optional[torch.Tensor] = None,
  470. # block_table: Optional[torch.Tensor] = None,
  471. softmax_scale=None,
  472. causal=False,
  473. window_size=(-1, -1), # -1 means infinite context window
  474. # softcap=0.0, # 0.0 means deactivated
  475. # rotary_interleaved=True,
  476. # alibi_slopes=None,
  477. num_splits=0,
  478. return_softmax_lse=False,
  479. gqa_parallel=None,
  480. max_seqlen_k_hint=None,
  481. descale_q=None,
  482. descale_k=None,
  483. descale_v=None,
  484. ):
  485. """
  486. NOTE: The KV cache API for FlashAttention-3 is a work in progress. We reproduce the description
  487. from the FlashAttention-2 method of the same name below.
  488. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  489. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  490. the previous step, and update them with the new keys/values from the current step, and do
  491. attention with the updated cache, all in 1 kernel.
  492. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  493. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  494. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  495. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  496. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  497. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  498. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  499. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  500. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  501. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  502. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  503. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  504. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  505. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  506. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  507. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  508. 1 1 1 1 0
  509. 1 1 1 1 1
  510. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  511. 0 0
  512. 0 0
  513. 0 0
  514. 1 0
  515. 1 1
  516. If the row of the mask is all zero, the output will be zero.
  517. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  518. will only attend to keys between
  519. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  520. Note: Does not support backward pass.
  521. Arguments:
  522. q: (batch_size, seqlen, nheads, headdim)
  523. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  524. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  525. page_block_size must be a multiple of 256.
  526. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  527. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  528. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  529. k with k_cache, starting at the indices specified by cache_seqlens.
  530. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  531. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  532. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  533. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  534. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  535. KV cache.
  536. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  537. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  538. If the indices are not distinct, and k and v are provided, the values updated in the cache
  539. might come from any of the duplicate indices.
  540. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  541. block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  542. softmax_scale: float. The scaling of QK^T before applying softmax.
  543. Default to 1 / sqrt(headdim).
  544. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  545. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  546. softcap: float. Anything > 0 activates softcapping attention.
  547. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  548. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  549. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  550. (i.e. GPT-NeoX style).
  551. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  552. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  553. is added to the attention score of query i and key j.
  554. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  555. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  556. to automatically determine the number of splits.
  557. Don't change this unless you know what you are doing.
  558. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  559. Return:
  560. out: (batch_size, seqlen, nheads, headdim).
  561. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  562. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  563. normalization factor).
  564. """
  565. # unimplemented kwargs
  566. k=None
  567. v=None
  568. rotary_cos=None
  569. rotary_sin=None
  570. cache_leftpad=None
  571. block_table=None
  572. softcap=0.0
  573. rotary_interleaved=True
  574. alibi_slopes=None
  575. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  576. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  577. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  578. if softmax_scale is None:
  579. softmax_scale = q.shape[-1] ** (-0.5)
  580. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  581. cache_seqlens = torch.full(
  582. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  583. )
  584. cache_seqlens = maybe_contiguous(cache_seqlens)
  585. cache_batch_idx = maybe_contiguous(cache_batch_idx)
  586. # block_table = maybe_contiguous(block_table)
  587. if gqa_parallel is None:
  588. gqa_parallel = True if q.shape[1] <= 64 else False
  589. # not in gqa/mqa setup
  590. if q.shape[2] == k_cache.shape[2]:
  591. gqa_parallel = False
  592. if max_seqlen_k_hint is None:
  593. max_seqlen_k_hint = k_cache.shape[1]
  594. out, softmax_lse = flashattn_hopper_cuda.fwd_kvcache(
  595. q,
  596. k_cache,
  597. v_cache,
  598. k,
  599. v,
  600. cache_seqlens,
  601. rotary_cos,
  602. rotary_sin,
  603. cache_batch_idx,
  604. cache_leftpad,
  605. block_table,
  606. alibi_slopes,
  607. None,
  608. softmax_scale,
  609. descale_q,
  610. descale_k,
  611. descale_v,
  612. causal,
  613. window_size[0],
  614. window_size[1],
  615. softcap,
  616. rotary_interleaved,
  617. num_splits,
  618. max_seqlen_k_hint,
  619. gqa_parallel
  620. )
  621. return (out, softmax_lse) if return_softmax_lse else out