flash_attn_interface.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  def maybe_contiguous(x):
      """Return ``x`` with a unit stride in its last dimension.

      ``None`` is passed through untouched. A tensor whose innermost stride is
      already 1 is returned as-is (no copy); any other tensor is replaced by a
      contiguous copy.
      """
      if x is None or x.stride(-1) == 1:
          return x
      return x.contiguous()
  def _flash_attn_forward(
      q,
      k,
      v,
      softmax_scale,
      causal,
      window_size,
      descale_q=None,
      descale_k=None,
      descale_v=None,
      gqa_parallel=False,
  ):
      """Run the dense (fixed-length) FlashAttention-3 forward kernel.

      q, k and v are first given a unit stride in their last dimension.
      window_size is a (left, right) pair that is unpacked into two scalar
      kernel arguments. descale_q/k/v and gqa_parallel are forwarded to the
      kernel unchanged.

      Returns the 7-tuple produced by ``flashattn_hopper_cuda.fwd``:
      (out, q, k, v, out_padded, softmax_lse, S_dmask).
      """
      q_c, k_c, v_c = map(maybe_contiguous, (q, k, v))
      left, right = window_size[0], window_size[1]
      out, q_c, k_c, v_c, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
          q_c,
          k_c,
          v_c,
          None,  # NOTE(review): presumably an optional output buffer — confirm against the kernel
          softmax_scale,
          descale_q,
          descale_k,
          descale_v,
          causal,
          left,
          right,
          gqa_parallel,
      )
      return out, q_c, k_c, v_c, out_padded, softmax_lse, S_dmask
  def _flash_attn_backward(
      dout,
      q,
      k,
      v,
      out,
      softmax_lse,
      dq,
      dk,
      dv,
      softmax_scale,
      causal,
      window_size,
      deterministic=False,
  ):
      """Run the dense (fixed-length) FlashAttention-3 backward kernel.

      dq, dk and dv are caller-provided gradient buffers handed to the kernel.
      dout, q, k, v and out are given a unit last-dimension stride first.

      Returns (dq, dk, dv, softmax_d): the first four values produced by
      ``flashattn_hopper_cuda.bwd``; any further values are discarded.
      """
      # Only the incoming tensors need fixing up; the gradient buffers are
      # allocated by the caller and are expected to be contiguous already.
      dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))
      left, right = window_size[0], window_size[1]
      dq, dk, dv, softmax_d, *_extra = flashattn_hopper_cuda.bwd(
          dout,
          q,
          k,
          v,
          out,
          softmax_lse,
          dq,
          dk,
          dv,
          softmax_scale,
          causal,
          left,
          right,
          deterministic,
      )
      return dq, dk, dv, softmax_d
  def _flash_attn_varlen_forward(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale,
      causal,
      window_size=(-1, -1),
      seqused_q=None,
      seqused_k=None,
  ):
      """Run the variable-length (packed) FlashAttention-3 forward kernel.

      q, k and v are given a unit stride in their last dimension. Everything
      else is forwarded to ``flashattn_hopper_cuda.varlen_fwd`` unchanged;
      window_size is a (left, right) pair unpacked into two scalar arguments.

      Returns (out, q, k, v, out_padded, softmax_lse) exactly as produced by
      the kernel.
      """
      # Use the shared module-level helper (which also tolerates None) rather
      # than a locally shadowed copy.
      q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
      out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
          q,
          k,
          v,
          None,  # NOTE(review): presumably an optional output buffer — confirm
          cu_seqlens_q,
          cu_seqlens_k,
          seqused_q,
          seqused_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size[0],
          window_size[1],
      )
      return out, q, k, v, out_padded, softmax_lse
  def _flash_attn_varlen_backward(
      dout,
      q,
      k,
      v,
      out,
      softmax_lse,
      dq,
      dk,
      dv,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale,
      causal,
      window_size,
      deterministic=False,
      seqused_q=None,
      seqused_k=None,
  ):
      """Run the variable-length (packed) FlashAttention-3 backward kernel.

      dq, dk and dv are caller-provided gradient buffers handed to the kernel.
      dout, q, k, v and out are given a unit last-dimension stride first.
      window_size is a (left, right) pair unpacked into two scalar arguments.

      Returns (dq, dk, dv, softmax_d): the first four values produced by
      ``flashattn_hopper_cuda.varlen_bwd``; any further values are discarded.
      """
      # dq, dk, dv are allocated by the caller, so they should already be
      # contiguous; only the inputs are fixed up, via the shared module helper.
      dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
      dq, dk, dv, softmax_d, *_rest = flashattn_hopper_cuda.varlen_bwd(
          dout,
          q,
          k,
          v,
          out,
          softmax_lse,
          dq,
          dk,
          dv,
          cu_seqlens_q,
          cu_seqlens_k,
          seqused_q,
          seqused_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size[0],
          window_size[1],
          deterministic,
      )
      return dq, dk, dv, softmax_d
  class FlashAttnFunc(torch.autograd.Function):
      """Autograd wrapper around the dense (fixed-length) FlashAttention-3 kernels.

      forward returns (out, softmax_lse). backward produces gradients for q, k
      and v only; every other forward input receives ``None``.
      """

      @staticmethod
      def forward(
          ctx,
          q,
          k,
          v,
          softmax_scale,
          causal,
          window_size,
          deterministic=False,
          descale_q=None,
          descale_k=None,
          descale_v=None,
          gqa_parallel=False,
      ):
          # Default scale is 1 / sqrt(headdim), computed from the caller's q.
          scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
          out, q, k, v, out_padded, softmax_lse, _s_dmask = _flash_attn_forward(
              q,
              k,
              v,
              scale,
              causal,
              window_size,
              descale_q=descale_q,
              descale_k=descale_k,
              descale_v=descale_v,
              gqa_parallel=gqa_parallel,
          )
          # Keep the tensors returned by the kernel wrapper (not the originals)
          # plus the padded output and log-sum-exp for the backward pass.
          ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
          ctx.softmax_scale = scale
          ctx.causal, ctx.window_size = causal, window_size
          ctx.deterministic, ctx.gqa_parallel = deterministic, gqa_parallel
          return out, softmax_lse

      @staticmethod
      def backward(ctx, dout, *args):
          q, k, v, out, softmax_lse = ctx.saved_tensors
          dq = torch.empty_like(q)
          dk = torch.empty_like(k)
          dv = torch.empty_like(v)
          _flash_attn_backward(
              dout,
              q,
              k,
              v,
              out,
              softmax_lse,
              dq,
              dk,
              dv,
              ctx.softmax_scale,
              ctx.causal,
              ctx.window_size,
              ctx.deterministic,
          )
          # The head dimension could have been padded; trim back to dout's size.
          head_dim = dout.shape[-1]
          grads = tuple(g[..., :head_dim] for g in (dq, dk, dv))
          # 8 non-differentiable inputs follow q, k, v in forward's signature.
          return grads + (None,) * 8
  class FlashAttnVarlenFunc(torch.autograd.Function):
      """Autograd wrapper around the variable-length FlashAttention-3 kernels.

      forward returns (out, softmax_lse). backward produces gradients for q, k
      and v only; every other forward input receives ``None``.
      """

      @staticmethod
      def forward(
          ctx,
          q,
          k,
          v,
          cu_seqlens_q,
          cu_seqlens_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size,
          deterministic=False,
          seqused_q=None,
          seqused_k=None,
      ):
          # Default scale is 1 / sqrt(headdim), computed from the caller's q.
          scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
          out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
              q,
              k,
              v,
              cu_seqlens_q,
              cu_seqlens_k,
              max_seqlen_q,
              max_seqlen_k,
              scale,
              causal=causal,
              window_size=window_size,
              seqused_q=seqused_q,
              seqused_k=seqused_k,
          )
          # seqused_q / seqused_k may be None; save_for_backward accepts that.
          ctx.save_for_backward(
              q,
              k,
              v,
              out_padded,
              softmax_lse,
              cu_seqlens_q,
              cu_seqlens_k,
              seqused_q,
              seqused_k,
          )
          ctx.max_seqlen_q, ctx.max_seqlen_k = max_seqlen_q, max_seqlen_k
          ctx.softmax_scale = scale
          ctx.causal, ctx.window_size = causal, window_size
          ctx.deterministic = deterministic
          return out, softmax_lse

      @staticmethod
      def backward(ctx, dout, *args):
          (
              q,
              k,
              v,
              out,
              softmax_lse,
              cu_seqlens_q,
              cu_seqlens_k,
              seqused_q,
              seqused_k,
          ) = ctx.saved_tensors
          dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
          _flash_attn_varlen_backward(
              dout,
              q,
              k,
              v,
              out,
              softmax_lse,
              dq,
              dk,
              dv,
              cu_seqlens_q,
              cu_seqlens_k,
              ctx.max_seqlen_q,
              ctx.max_seqlen_k,
              ctx.softmax_scale,
              ctx.causal,
              ctx.window_size,
              ctx.deterministic,
              seqused_q=seqused_q,
              seqused_k=seqused_k,
          )
          # The head dimension could have been padded; trim back to dout's size.
          head_dim = dout.shape[-1]
          grads = tuple(g[..., :head_dim] for g in (dq, dk, dv))
          # 10 non-differentiable inputs follow q, k, v in forward's signature.
          return grads + (None,) * 10
  def flash_attn_func(
      q,
      k,
      v,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),
      deterministic=False,
      descale_q=None,
      descale_k=None,
      descale_v=None,
      gqa_parallel=False,
  ):
      """Dense (fixed-length) FlashAttention-3 attention with autograd support.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V
      with fewer heads than Q. The number of heads in Q must be divisible by the
      number of heads in K, V. For example, if Q has 6 heads and K, V have 2 heads,
      heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q attend to
      head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the
      attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
      mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If a row of the mask is all zero, the output for that row will be zero.

      If window_size != (-1, -1), implements sliding window local attention. Query at
      position i will only attend to keys between
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      inclusive.

      Gradients are produced for q, k and v only.

      Arguments:
          q: (batch_size, seqlen_q, nheads, headdim)
          k: (batch_size, seqlen_k, nheads_k, headdim)
          v: (batch_size, seqlen_k, nheads_k, headdim)
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Defaults to 1 / sqrt(headdim) when None.
          causal: bool. Whether to apply causal attention mask (e.g., for
              auto-regressive modeling).
          window_size: (left, right). If not (-1, -1), implements sliding window
              local attention.
          deterministic: bool. Whether to use the deterministic implementation of the
              backward pass, which is slightly slower and uses more memory. The
              forward pass is always deterministic.
          descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
          descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
          descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
          gqa_parallel: bool. Forwarded unchanged to the forward kernel; it is not
              used by the backward pass. NOTE(review): presumably selects a
              GQA-specific parallelization strategy — confirm against the kernel.

      Return:
          out: (batch_size, seqlen_q, nheads, headdim).
          softmax_lse: (batch_size, nheads, seqlen_q). Always returned. The logsumexp
              of each row of the matrix QK^T * scaling (e.g., log of the softmax
              normalization factor).
      """
      return FlashAttnFunc.apply(
          q,
          k,
          v,
          softmax_scale,
          causal,
          window_size,
          deterministic,
          descale_q,
          descale_k,
          descale_v,
          gqa_parallel,
      )
  def flash_attn_varlen_func(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),
      deterministic=False,
      seqused_q=None,
      seqused_k=None,
  ):
      """Variable-length (packed) FlashAttention-3 attention with autograd support.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V
      with fewer heads than Q. The number of heads in Q must be divisible by the
      number of heads in K, V. For example, if Q has 6 heads and K, V have 2 heads,
      heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q attend to
      head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the
      attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
      mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If a row of the mask is all zero, the output for that row will be zero.

      Gradients are produced for q, k and v only.

      Arguments:
          q: (total_q, nheads, headdim), where total_q = total number of query tokens
              in the batch.
          k: (total_k, nheads_k, headdim), where total_k = total number of key tokens
              in the batch.
          v: (total_k, nheads_k, headdim), where total_k = total number of key tokens
              in the batch.
          cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence
              lengths of the sequences in the batch, used to index into q.
          cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence
              lengths of the sequences in the batch, used to index into k and v.
          max_seqlen_q: int. Maximum query sequence length in the batch.
          max_seqlen_k: int. Maximum key sequence length in the batch.
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Defaults to 1 / sqrt(headdim) when None.
          causal: bool. Whether to apply causal attention mask (e.g., for
              auto-regressive modeling).
          window_size: (left, right). If not (-1, -1), implements sliding window
              local attention.
          deterministic: bool. Whether to use the deterministic implementation of the
              backward pass.
          seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the
              actual number of query and output tokens in each sequence.
          seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the
              actual number of key and value tokens in each sequence.

      Return:
          out: (total_q, nheads, headdim).
          softmax_lse: (nheads, total_q). Always returned. The logsumexp of each row
              of the matrix QK^T * scaling (e.g., log of the softmax normalization
              factor).
      """
      return FlashAttnVarlenFunc.apply(
          q,
          k,
          v,
          cu_seqlens_q,
          cu_seqlens_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          window_size,
          deterministic,
          seqused_q,
          seqused_k,
      )
  def flash_attn_with_kvcache(
      q,
      k_cache,
      v_cache,
      # k, v, rotary_cos, rotary_sin (FlashAttention-2 arguments) are not exposed yet.
      cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
      cache_batch_idx: Optional[torch.Tensor] = None,
      # cache_leftpad and block_table are not exposed yet.
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),  # -1 means infinite context window
      # softcap, rotary_interleaved and alibi_slopes are not exposed yet.
      num_splits=0,
      return_softmax_lse=False,
      gqa_parallel=None,
      max_seqlen_k_hint=None,
      descale_q=None,
      descale_k=None,
      descale_v=None,
  ):
      """Attention of q against a pre-filled KV cache (FlashAttention-3, forward only).

      NOTE: The KV cache API for FlashAttention-3 is a work in progress. Compared with
      the FlashAttention-2 function of the same name, the following arguments are NOT
      exposed yet and are always passed to the kernel in their disabled state:
      k, v (appending new keys/values to the cache), rotary_cos, rotary_sin,
      rotary_interleaved, cache_leftpad, block_table (paged KV cache), softcap and
      alibi_slopes. Because k and v are always None, this function does not perform
      the in-place cache update that FlashAttention-2 does when k and v are given.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in a KV
      cache with fewer heads than Q. The number of heads in Q must be divisible by the
      number of heads in the cache. For example, if Q has 6 heads and K, V have 2
      heads, heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q
      attend to head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the
      attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
      mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If a row of the mask is all zero, the output for that row will be zero.

      If window_size != (-1, -1), implements sliding window local attention. Query at
      position i will only attend to keys between
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      inclusive.

      Does not support the backward pass.

      Arguments:
          q: (batch_size, seqlen, nheads, headdim)
          k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim). Its last
              dimension must be contiguous (checked with an assert).
          v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim). Its last
              dimension must be contiguous (checked with an assert).
          cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths
              of the KV cache. An int is broadcast to a (batch_size,) int32 tensor,
              where batch_size = q.shape[0].
          cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index
              into the KV cache. If None, we assume that the batch indices are
              [0, 1, 2, ..., batch_size - 1].
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Defaults to 1 / sqrt(headdim) when None.
          causal: bool. Whether to apply causal attention mask (e.g., for
              auto-regressive modeling).
          window_size: (left, right). If not (-1, -1), implements sliding window
              local attention.
          num_splits: int. If > 1, split the key/value into this many chunks along the
              sequence. If num_splits == 1, we don't split the key/value. If
              num_splits == 0, we use a heuristic to automatically determine the
              number of splits. Don't change this unless you know what you are doing.
          return_softmax_lse: bool. Whether to also return the logsumexp of the
              attention scores.
          gqa_parallel: bool or None. Forwarded to the kernel. When None it defaults
              to True if seqlen (q.shape[1]) <= 64, else False. It is always forced to
              False when nheads == nheads_k (no GQA/MQA), even if True was passed.
          max_seqlen_k_hint: int or None. Forwarded to the kernel; defaults to
              seqlen_cache (k_cache.shape[1]).
          descale_q, descale_k, descale_v: optional tensors forwarded to the kernel
              unchanged. NOTE(review): presumably fp8 de-quantization factors as in
              flash_attn_func — confirm.

      Return:
          out: (batch_size, seqlen, nheads, headdim).
          softmax_lse [only if return_softmax_lse=True]: (batch_size, nheads, seqlen).
              The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the
              softmax normalization factor).
      """
      # Kernel arguments this wrapper does not expose yet; they are passed
      # positionally to the kernel in their disabled state.
      k = None
      v = None
      rotary_cos = None
      rotary_sin = None
      cache_leftpad = None
      block_table = None
      softcap = 0.0
      rotary_interleaved = True
      alibi_slopes = None

      assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
      assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
      q = maybe_contiguous(q)  # k and v are always None here
      if softmax_scale is None:
          softmax_scale = q.shape[-1] ** (-0.5)
      if isinstance(cache_seqlens, int):
          # cache_seqlens is documented as (batch_size,), the batch of q. This is
          # not k_cache.shape[0], which is larger when cache_batch_idx selects
          # rows from a bigger cache.
          cache_seqlens = torch.full(
              (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
          )
      cache_seqlens = maybe_contiguous(cache_seqlens)
      cache_batch_idx = maybe_contiguous(cache_batch_idx)
      if gqa_parallel is None:
          # Heuristic default: enable for short query sequences.
          gqa_parallel = q.shape[1] <= 64
      if q.shape[2] == k_cache.shape[2]:
          # Same number of query and KV heads: not a GQA/MQA setup.
          gqa_parallel = False
      if max_seqlen_k_hint is None:
          max_seqlen_k_hint = k_cache.shape[1]
      out, softmax_lse = flashattn_hopper_cuda.fwd_kvcache(
          q,
          k_cache,
          v_cache,
          k,
          v,
          cache_seqlens,
          rotary_cos,
          rotary_sin,
          cache_batch_idx,
          cache_leftpad,
          block_table,
          alibi_slopes,
          None,
          softmax_scale,
          descale_q,
          descale_k,
          descale_v,
          causal,
          window_size[0],
          window_size[1],
          softcap,
          rotary_interleaved,
          num_splits,
          max_seqlen_k_hint,
          gqa_parallel,
      )
      return (out, softmax_lse) if return_softmax_lse else out