flash_attn_interface.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flash_attn_2_cuda as flash_attn_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _get_block_size_n(device, head_dim, is_dropout, is_causal):
  12. # This should match the block sizes in the CUDA kernel
  13. assert head_dim <= 256
  14. major, minor = torch.cuda.get_device_capability(device)
  15. is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
  16. is_sm80 = major == 8 and minor == 0
  17. is_sm90 = major == 9 and minor == 0
  18. if head_dim <= 32:
  19. return 128
  20. if head_dim <= 64:
  21. return 128 if not is_dropout else 64
  22. elif head_dim <= 96:
  23. return 64
  24. elif head_dim <= 128:
  25. if is_sm8x:
  26. return 64 if (not is_dropout and is_causal) else 32
  27. else:
  28. return 64 if not is_dropout else 32
  29. elif head_dim <= 160:
  30. if is_sm8x:
  31. return 64
  32. else:
  33. return 32
  34. elif head_dim <= 192:
  35. return 64
  36. elif head_dim <= 224:
  37. return 64
  38. elif head_dim <= 256:
  39. return 64
  40. def _flash_attn_forward(
  41. q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
  42. ):
  43. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  44. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
  45. q,
  46. k,
  47. v,
  48. None,
  49. alibi_slopes,
  50. dropout_p,
  51. softmax_scale,
  52. causal,
  53. window_size[0],
  54. window_size[1],
  55. softcap,
  56. return_softmax,
  57. None,
  58. )
  59. return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
  60. def _flash_attn_varlen_forward(
  61. q,
  62. k,
  63. v,
  64. cu_seqlens_q,
  65. cu_seqlens_k,
  66. max_seqlen_q,
  67. max_seqlen_k,
  68. dropout_p,
  69. softmax_scale,
  70. causal,
  71. window_size=(-1, -1),
  72. softcap=0.0,
  73. alibi_slopes=None,
  74. return_softmax=False,
  75. block_table=None,
  76. leftpad_k=None,
  77. seqused_k=None,
  78. ):
  79. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  80. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
  81. q,
  82. k,
  83. v,
  84. None,
  85. cu_seqlens_q,
  86. cu_seqlens_k,
  87. seqused_k,
  88. leftpad_k,
  89. block_table,
  90. alibi_slopes,
  91. max_seqlen_q,
  92. max_seqlen_k,
  93. dropout_p,
  94. softmax_scale,
  95. False,
  96. causal,
  97. window_size[0],
  98. window_size[1],
  99. softcap,
  100. return_softmax,
  101. None,
  102. )
  103. # if out.isnan().any() or softmax_lse.isnan().any():
  104. # breakpoint()
  105. return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
  106. def _flash_attn_backward(
  107. dout,
  108. q,
  109. k,
  110. v,
  111. out,
  112. softmax_lse,
  113. dq,
  114. dk,
  115. dv,
  116. dropout_p,
  117. softmax_scale,
  118. causal,
  119. window_size,
  120. softcap,
  121. alibi_slopes,
  122. deterministic,
  123. rng_state=None,
  124. ):
  125. # dq, dk, dv are allocated by us so they should already be contiguous
  126. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  127. (
  128. dq,
  129. dk,
  130. dv,
  131. softmax_d,
  132. ) = flash_attn_cuda.bwd(
  133. dout,
  134. q,
  135. k,
  136. v,
  137. out,
  138. softmax_lse,
  139. dq,
  140. dk,
  141. dv,
  142. alibi_slopes,
  143. dropout_p,
  144. softmax_scale,
  145. causal,
  146. window_size[0],
  147. window_size[1],
  148. softcap,
  149. deterministic,
  150. None,
  151. rng_state,
  152. )
  153. return dq, dk, dv, softmax_d
  154. def _flash_attn_varlen_backward(
  155. dout,
  156. q,
  157. k,
  158. v,
  159. out,
  160. softmax_lse,
  161. dq,
  162. dk,
  163. dv,
  164. cu_seqlens_q,
  165. cu_seqlens_k,
  166. max_seqlen_q,
  167. max_seqlen_k,
  168. dropout_p,
  169. softmax_scale,
  170. causal,
  171. window_size,
  172. softcap,
  173. alibi_slopes,
  174. deterministic,
  175. rng_state=None,
  176. ):
  177. # dq, dk, dv are allocated by us so they should already be contiguous
  178. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  179. (
  180. dq,
  181. dk,
  182. dv,
  183. softmax_d,
  184. ) = flash_attn_cuda.varlen_bwd(
  185. dout,
  186. q,
  187. k,
  188. v,
  189. out,
  190. softmax_lse,
  191. dq,
  192. dk,
  193. dv,
  194. cu_seqlens_q,
  195. cu_seqlens_k,
  196. alibi_slopes,
  197. max_seqlen_q,
  198. max_seqlen_k,
  199. dropout_p,
  200. softmax_scale,
  201. False,
  202. causal,
  203. window_size[0],
  204. window_size[1],
  205. softcap,
  206. deterministic,
  207. None,
  208. rng_state,
  209. )
  210. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  211. # breakpoint()
  212. return dq, dk, dv, softmax_d
  213. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  214. @staticmethod
  215. def forward(
  216. ctx,
  217. qkv,
  218. dropout_p,
  219. softmax_scale,
  220. causal,
  221. window_size,
  222. softcap,
  223. alibi_slopes,
  224. deterministic,
  225. return_softmax,
  226. ):
  227. if softmax_scale is None:
  228. softmax_scale = qkv.shape[-1] ** (-0.5)
  229. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  230. qkv[:, :, 0],
  231. qkv[:, :, 1],
  232. qkv[:, :, 2],
  233. dropout_p,
  234. softmax_scale,
  235. causal=causal,
  236. window_size=window_size,
  237. softcap=softcap,
  238. alibi_slopes=alibi_slopes,
  239. return_softmax=return_softmax and dropout_p > 0,
  240. )
  241. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  242. ctx.dropout_p = dropout_p
  243. ctx.softmax_scale = softmax_scale
  244. ctx.causal = causal
  245. ctx.window_size = window_size
  246. ctx.softcap = softcap
  247. ctx.alibi_slopes = alibi_slopes
  248. ctx.deterministic = deterministic
  249. return out if not return_softmax else (out, softmax_lse, S_dmask)
  250. @staticmethod
  251. def backward(ctx, dout, *args):
  252. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  253. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  254. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  255. _flash_attn_backward(
  256. dout,
  257. q,
  258. k,
  259. v,
  260. out,
  261. softmax_lse,
  262. dqkv[:, :, 0],
  263. dqkv[:, :, 1],
  264. dqkv[:, :, 2],
  265. ctx.dropout_p,
  266. ctx.softmax_scale,
  267. ctx.causal,
  268. ctx.window_size,
  269. ctx.softcap,
  270. ctx.alibi_slopes,
  271. ctx.deterministic,
  272. rng_state=rng_state,
  273. )
  274. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  275. return dqkv, None, None, None, None, None, None, None, None
  276. class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
  277. @staticmethod
  278. def forward(
  279. ctx,
  280. qkv,
  281. cu_seqlens,
  282. max_seqlen,
  283. dropout_p,
  284. softmax_scale,
  285. causal,
  286. window_size,
  287. softcap,
  288. alibi_slopes,
  289. deterministic,
  290. return_softmax,
  291. ):
  292. if softmax_scale is None:
  293. softmax_scale = qkv.shape[-1] ** (-0.5)
  294. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  295. qkv[:, 0],
  296. qkv[:, 1],
  297. qkv[:, 2],
  298. cu_seqlens,
  299. cu_seqlens,
  300. max_seqlen,
  301. max_seqlen,
  302. dropout_p,
  303. softmax_scale,
  304. causal=causal,
  305. window_size=window_size,
  306. softcap=softcap,
  307. alibi_slopes=alibi_slopes,
  308. return_softmax=return_softmax and dropout_p > 0,
  309. block_table=None,
  310. )
  311. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
  312. ctx.dropout_p = dropout_p
  313. ctx.max_seqlen = max_seqlen
  314. ctx.softmax_scale = softmax_scale
  315. ctx.causal = causal
  316. ctx.window_size = window_size
  317. ctx.softcap = softcap
  318. ctx.alibi_slopes = alibi_slopes
  319. ctx.deterministic = deterministic
  320. return out if not return_softmax else (out, softmax_lse, S_dmask)
  321. @staticmethod
  322. def backward(ctx, dout, *args):
  323. q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
  324. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  325. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  326. _flash_attn_varlen_backward(
  327. dout,
  328. q,
  329. k,
  330. v,
  331. out,
  332. softmax_lse,
  333. dqkv[:, 0],
  334. dqkv[:, 1],
  335. dqkv[:, 2],
  336. cu_seqlens,
  337. cu_seqlens,
  338. ctx.max_seqlen,
  339. ctx.max_seqlen,
  340. ctx.dropout_p,
  341. ctx.softmax_scale,
  342. ctx.causal,
  343. ctx.window_size,
  344. ctx.softcap,
  345. ctx.alibi_slopes,
  346. ctx.deterministic,
  347. rng_state=rng_state,
  348. )
  349. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  350. return dqkv, None, None, None, None, None, None, None, None, None, None
  351. class FlashAttnKVPackedFunc(torch.autograd.Function):
  352. @staticmethod
  353. def forward(
  354. ctx,
  355. q,
  356. kv,
  357. dropout_p,
  358. softmax_scale,
  359. causal,
  360. window_size,
  361. softcap,
  362. alibi_slopes,
  363. deterministic,
  364. return_softmax,
  365. ):
  366. if softmax_scale is None:
  367. softmax_scale = q.shape[-1] ** (-0.5)
  368. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  369. q,
  370. kv[:, :, 0],
  371. kv[:, :, 1],
  372. dropout_p,
  373. softmax_scale,
  374. causal=causal,
  375. window_size=window_size,
  376. softcap=softcap,
  377. alibi_slopes=alibi_slopes,
  378. return_softmax=return_softmax and dropout_p > 0,
  379. )
  380. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  381. ctx.dropout_p = dropout_p
  382. ctx.softmax_scale = softmax_scale
  383. ctx.causal = causal
  384. ctx.window_size = window_size
  385. ctx.softcap = softcap
  386. ctx.alibi_slopes = alibi_slopes
  387. ctx.deterministic = deterministic
  388. return out if not return_softmax else (out, softmax_lse, S_dmask)
  389. @staticmethod
  390. def backward(ctx, dout, *args):
  391. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  392. dq = torch.empty_like(q)
  393. kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
  394. dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
  395. _flash_attn_backward(
  396. dout,
  397. q,
  398. k,
  399. v,
  400. out,
  401. softmax_lse,
  402. dq,
  403. dkv[:, :, 0],
  404. dkv[:, :, 1],
  405. ctx.dropout_p,
  406. ctx.softmax_scale,
  407. ctx.causal,
  408. ctx.window_size,
  409. ctx.softcap,
  410. ctx.alibi_slopes,
  411. ctx.deterministic,
  412. rng_state=rng_state,
  413. )
  414. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  415. dkv = dkv[..., : dout.shape[-1]]
  416. return dq, dkv, None, None, None, None, None, None, None, None
  417. class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
  418. @staticmethod
  419. def forward(
  420. ctx,
  421. q,
  422. kv,
  423. cu_seqlens_q,
  424. cu_seqlens_k,
  425. max_seqlen_q,
  426. max_seqlen_k,
  427. dropout_p,
  428. softmax_scale,
  429. causal,
  430. window_size,
  431. softcap,
  432. alibi_slopes,
  433. deterministic,
  434. return_softmax,
  435. ):
  436. if softmax_scale is None:
  437. softmax_scale = q.shape[-1] ** (-0.5)
  438. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  439. q,
  440. kv[:, 0],
  441. kv[:, 1],
  442. cu_seqlens_q,
  443. cu_seqlens_k,
  444. max_seqlen_q,
  445. max_seqlen_k,
  446. dropout_p,
  447. softmax_scale,
  448. causal=causal,
  449. window_size=window_size,
  450. softcap=softcap,
  451. alibi_slopes=alibi_slopes,
  452. return_softmax=return_softmax and dropout_p > 0,
  453. block_table=None,
  454. )
  455. ctx.save_for_backward(
  456. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
  457. )
  458. ctx.dropout_p = dropout_p
  459. ctx.max_seqlen_q = max_seqlen_q
  460. ctx.max_seqlen_k = max_seqlen_k
  461. ctx.softmax_scale = softmax_scale
  462. ctx.causal = causal
  463. ctx.window_size = window_size
  464. ctx.softcap = softcap
  465. ctx.alibi_slopes = alibi_slopes
  466. ctx.deterministic = deterministic
  467. return out if not return_softmax else (out, softmax_lse, S_dmask)
  468. @staticmethod
  469. def backward(ctx, dout, *args):
  470. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
  471. dq = torch.empty_like(q)
  472. kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
  473. dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
  474. _flash_attn_varlen_backward(
  475. dout,
  476. q,
  477. k,
  478. v,
  479. out,
  480. softmax_lse,
  481. dq,
  482. dkv[:, 0],
  483. dkv[:, 1],
  484. cu_seqlens_q,
  485. cu_seqlens_k,
  486. ctx.max_seqlen_q,
  487. ctx.max_seqlen_k,
  488. ctx.dropout_p,
  489. ctx.softmax_scale,
  490. ctx.causal,
  491. ctx.window_size,
  492. ctx.softcap,
  493. ctx.alibi_slopes,
  494. ctx.deterministic,
  495. rng_state=rng_state,
  496. )
  497. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  498. dkv = dkv[..., : dout.shape[-1]]
  499. return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
  500. class FlashAttnFunc(torch.autograd.Function):
  501. @staticmethod
  502. def forward(
  503. ctx,
  504. q,
  505. k,
  506. v,
  507. dropout_p,
  508. softmax_scale,
  509. causal,
  510. window_size,
  511. softcap,
  512. alibi_slopes,
  513. deterministic,
  514. return_softmax,
  515. ):
  516. if softmax_scale is None:
  517. softmax_scale = q.shape[-1] ** (-0.5)
  518. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  519. q,
  520. k,
  521. v,
  522. dropout_p,
  523. softmax_scale,
  524. causal=causal,
  525. window_size=window_size,
  526. softcap=softcap,
  527. alibi_slopes=alibi_slopes,
  528. return_softmax=return_softmax and dropout_p > 0,
  529. )
  530. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
  531. ctx.dropout_p = dropout_p
  532. ctx.softmax_scale = softmax_scale
  533. ctx.causal = causal
  534. ctx.window_size = window_size
  535. ctx.softcap = softcap
  536. ctx.alibi_slopes = alibi_slopes
  537. ctx.deterministic = deterministic
  538. return out if not return_softmax else (out, softmax_lse, S_dmask)
  539. @staticmethod
  540. def backward(ctx, dout, *args):
  541. q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
  542. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  543. _flash_attn_backward(
  544. dout,
  545. q,
  546. k,
  547. v,
  548. out,
  549. softmax_lse,
  550. dq,
  551. dk,
  552. dv,
  553. ctx.dropout_p,
  554. ctx.softmax_scale,
  555. ctx.causal,
  556. ctx.window_size,
  557. ctx.softcap,
  558. ctx.alibi_slopes,
  559. ctx.deterministic,
  560. rng_state=rng_state,
  561. )
  562. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  563. dk = dk[..., : dout.shape[-1]]
  564. dv = dv[..., : dout.shape[-1]]
  565. return dq, dk, dv, None, None, None, None, None, None, None, None
  566. class FlashAttnVarlenFunc(torch.autograd.Function):
  567. @staticmethod
  568. def forward(
  569. ctx,
  570. q,
  571. k,
  572. v,
  573. cu_seqlens_q,
  574. cu_seqlens_k,
  575. max_seqlen_q,
  576. max_seqlen_k,
  577. dropout_p,
  578. softmax_scale,
  579. causal,
  580. window_size,
  581. softcap,
  582. alibi_slopes,
  583. deterministic,
  584. return_softmax,
  585. block_table,
  586. ):
  587. if softmax_scale is None:
  588. softmax_scale = q.shape[-1] ** (-0.5)
  589. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  590. q,
  591. k,
  592. v,
  593. cu_seqlens_q,
  594. cu_seqlens_k,
  595. max_seqlen_q,
  596. max_seqlen_k,
  597. dropout_p,
  598. softmax_scale,
  599. causal=causal,
  600. window_size=window_size,
  601. softcap=softcap,
  602. alibi_slopes=alibi_slopes,
  603. return_softmax=return_softmax and dropout_p > 0,
  604. block_table=block_table,
  605. )
  606. ctx.save_for_backward(
  607. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
  608. )
  609. ctx.dropout_p = dropout_p
  610. ctx.max_seqlen_q = max_seqlen_q
  611. ctx.max_seqlen_k = max_seqlen_k
  612. ctx.softmax_scale = softmax_scale
  613. ctx.causal = causal
  614. ctx.window_size = window_size
  615. ctx.softcap = softcap
  616. ctx.alibi_slopes = alibi_slopes
  617. ctx.deterministic = deterministic
  618. return out if not return_softmax else (out, softmax_lse, S_dmask)
  619. @staticmethod
  620. def backward(ctx, dout, *args):
  621. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
  622. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  623. _flash_attn_varlen_backward(
  624. dout,
  625. q,
  626. k,
  627. v,
  628. out,
  629. softmax_lse,
  630. dq,
  631. dk,
  632. dv,
  633. cu_seqlens_q,
  634. cu_seqlens_k,
  635. ctx.max_seqlen_q,
  636. ctx.max_seqlen_k,
  637. ctx.dropout_p,
  638. ctx.softmax_scale,
  639. ctx.causal,
  640. ctx.window_size,
  641. ctx.softcap,
  642. ctx.alibi_slopes,
  643. ctx.deterministic,
  644. rng_state=rng_state,
  645. )
  646. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  647. dk = dk[..., : dout.shape[-1]]
  648. dv = dv[..., : dout.shape[-1]]
  649. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
  650. def flash_attn_qkvpacked_func(
  651. qkv,
  652. dropout_p=0.0,
  653. softmax_scale=None,
  654. causal=False,
  655. window_size=(-1, -1), # -1 means infinite context window
  656. softcap=0.0, # <=0.0 means deactivate
  657. alibi_slopes=None,
  658. deterministic=False,
  659. return_attn_probs=False,
  660. ):
  661. """dropout_p should be set to 0.0 during evaluation
  662. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  663. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  664. of the gradients of Q, K, V.
  665. For multi-query and grouped-query attention (MQA/GQA), please see
  666. flash_attn_kvpacked_func and flash_attn_func.
  667. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  668. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  669. Arguments:
  670. qkv: (batch_size, seqlen, 3, nheads, headdim)
  671. dropout_p: float. Dropout probability.
  672. softmax_scale: float. The scaling of QK^T before applying softmax.
  673. Default to 1 / sqrt(headdim).
  674. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  675. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  676. softcap: float. Anything > 0 activates softcapping attention.
  677. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  678. the attention score of query i and key j.
  679. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  680. which is slightly slower and uses more memory. The forward pass is always deterministic.
  681. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  682. testing only. The returned probabilities are not guaranteed to be correct
  683. (they might not have the right scaling).
  684. Return:
  685. out: (batch_size, seqlen, nheads, headdim).
  686. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  687. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  688. normalization factor).
  689. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  690. The output of softmax (possibly with different scaling). It also encodes the dropout
  691. pattern (negative means that location was dropped, nonnegative means it was kept).
  692. """
  693. return FlashAttnQKVPackedFunc.apply(
  694. qkv,
  695. dropout_p,
  696. softmax_scale,
  697. causal,
  698. window_size,
  699. softcap,
  700. alibi_slopes,
  701. deterministic,
  702. return_attn_probs,
  703. )
  704. def flash_attn_kvpacked_func(
  705. q,
  706. kv,
  707. dropout_p=0.0,
  708. softmax_scale=None,
  709. causal=False,
  710. window_size=(-1, -1), # -1 means infinite context window
  711. softcap=0.0, # 0.0 means deactivated
  712. alibi_slopes=None,
  713. deterministic=False,
  714. return_attn_probs=False,
  715. ):
  716. """dropout_p should be set to 0.0 during evaluation
  717. If K, V are already stacked into 1 tensor, this function will be faster than
  718. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  719. of the gradients of K, V.
  720. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  721. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  722. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  723. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  724. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  725. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  726. 1 1 1 1 0
  727. 1 1 1 1 1
  728. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  729. 0 0
  730. 0 0
  731. 0 0
  732. 1 0
  733. 1 1
  734. If the row of the mask is all zero, the output will be zero.
  735. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  736. will only attend to keys between
  737. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  738. Arguments:
  739. q: (batch_size, seqlen, nheads, headdim)
  740. kv: (batch_size, seqlen, 2, nheads_k, headdim)
  741. dropout_p: float. Dropout probability.
  742. softmax_scale: float. The scaling of QK^T before applying softmax.
  743. Default to 1 / sqrt(headdim).
  744. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  745. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  746. softcap: float. Anything > 0 activates softcapping attention.
  747. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  748. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  749. is added to the attention score of query i and key j.
  750. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  751. which is slightly slower and uses more memory. The forward pass is always deterministic.
  752. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  753. testing only. The returned probabilities are not guaranteed to be correct
  754. (they might not have the right scaling).
  755. Return:
  756. out: (batch_size, seqlen, nheads, headdim).
  757. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  758. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  759. normalization factor).
  760. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  761. The output of softmax (possibly with different scaling). It also encodes the dropout
  762. pattern (negative means that location was dropped, nonnegative means it was kept).
  763. """
  764. return FlashAttnKVPackedFunc.apply(
  765. q,
  766. kv,
  767. dropout_p,
  768. softmax_scale,
  769. causal,
  770. window_size,
  771. softcap,
  772. alibi_slopes,
  773. deterministic,
  774. return_attn_probs,
  775. )
  776. def flash_attn_func(
  777. q,
  778. k,
  779. v,
  780. dropout_p=0.0,
  781. softmax_scale=None,
  782. causal=False,
  783. window_size=(-1, -1), # -1 means infinite context window
  784. softcap=0.0, # 0.0 means deactivated
  785. alibi_slopes=None,
  786. deterministic=False,
  787. return_attn_probs=False,
  788. ):
  789. """dropout_p should be set to 0.0 during evaluation
  790. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  791. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  792. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  793. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  794. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  795. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  796. 1 1 1 1 0
  797. 1 1 1 1 1
  798. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  799. 0 0
  800. 0 0
  801. 0 0
  802. 1 0
  803. 1 1
  804. If the row of the mask is all zero, the output will be zero.
  805. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  806. will only attend to keys between
  807. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  808. Arguments:
  809. q: (batch_size, seqlen, nheads, headdim)
  810. k: (batch_size, seqlen, nheads_k, headdim)
  811. v: (batch_size, seqlen, nheads_k, headdim)
  812. dropout_p: float. Dropout probability.
  813. softmax_scale: float. The scaling of QK^T before applying softmax.
  814. Default to 1 / sqrt(headdim).
  815. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  816. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  817. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  818. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  819. is added to the attention score of query i and key j.
  820. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  821. which is slightly slower and uses more memory. The forward pass is always deterministic.
  822. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  823. testing only. The returned probabilities are not guaranteed to be correct
  824. (they might not have the right scaling).
  825. Return:
  826. out: (batch_size, seqlen, nheads, headdim).
  827. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  828. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  829. normalization factor).
  830. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  831. The output of softmax (possibly with different scaling). It also encodes the dropout
  832. pattern (negative means that location was dropped, nonnegative means it was kept).
  833. """
  834. return FlashAttnFunc.apply(
  835. q,
  836. k,
  837. v,
  838. dropout_p,
  839. softmax_scale,
  840. causal,
  841. window_size,
  842. softcap,
  843. alibi_slopes,
  844. deterministic,
  845. return_attn_probs,
  846. )
  def flash_attn_varlen_qkvpacked_func(
      qkv,
      cu_seqlens,
      max_seqlen,
      dropout_p=0.0,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),  # -1 means infinite context window
      softcap=0.0,  # 0.0 means deactivated
      alibi_slopes=None,
      deterministic=False,
      return_attn_probs=False,
  ):
      """Variable-length attention over a single packed QKV tensor.

      Set dropout_p to 0.0 during evaluation.

      When Q, K and V are already stacked into one tensor, this is faster than
      calling flash_attn_varlen_func on separate Q, K, V, because the backward
      pass does not have to explicitly concatenate the gradients of Q, K and V.
      For multi-query / grouped-query attention (MQA/GQA), use
      flash_attn_varlen_kvpacked_func or flash_attn_varlen_func instead.

      With window_size != (-1, -1), sliding-window local attention is used:
      query i attends only to keys in [i - window_size[0], i + window_size[1]]
      (inclusive).

      Arguments:
          qkv: (total, 3, nheads, headdim), where total = total number of tokens
              in the batch.
          cu_seqlens: (batch_size + 1,), dtype torch.int32. Cumulative sequence
              lengths of the sequences in the batch, used to index into qkv.
          max_seqlen: int. Maximum sequence length in the batch.
          dropout_p: float. Dropout probability.
          softmax_scale: float. Scaling applied to QK^T before the softmax.
              Defaults to 1 / sqrt(headdim).
          causal: bool. Whether to apply a causal mask (e.g. for auto-regressive
              modeling).
          window_size: (left, right). If not (-1, -1), enables sliding-window
              local attention.
          softcap: float. Any value > 0 enables softcapping.
          alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
              (-alibi_slope * |i - j|) is added to the attention score of query i
              and key j.
          deterministic: bool. Use the deterministic backward implementation,
              which is slightly slower and uses more memory. The forward pass is
              always deterministic.
          return_attn_probs: bool. Also return the attention probabilities.
              Testing only: the returned probabilities are not guaranteed to be
              correct (they may not have the right scaling).

      Return:
          out: (total, nheads, headdim).
          softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen).
              Logsumexp of each row of QK^T * scaling (i.e. the log of the softmax
              normalization factor).
          S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
              Softmax output (possibly with different scaling). Also encodes the
              dropout pattern: negative = dropped, nonnegative = kept.
      """
      # The autograd Function takes its inputs positionally, in exactly this order.
      forwarded = (
          qkv,
          cu_seqlens,
          max_seqlen,
          dropout_p,
          softmax_scale,
          causal,
          window_size,
          softcap,
          alibi_slopes,
          deterministic,
          return_attn_probs,
      )
      return FlashAttnVarlenQKVPackedFunc.apply(*forwarded)
  def flash_attn_varlen_kvpacked_func(
      q,
      kv,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      dropout_p=0.0,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),  # -1 means infinite context window
      softcap=0.0,  # 0.0 means deactivated
      alibi_slopes=None,
      deterministic=False,
      return_attn_probs=False,
  ):
      """Variable-length attention where K and V are packed into one tensor.

      Set dropout_p to 0.0 during evaluation.

      When K and V are already stacked into one tensor, this is faster than
      calling flash_attn_varlen_func on separate Q, K, V, because the backward
      pass does not have to explicitly concatenate the gradients of K and V.

      MQA/GQA is supported by giving KV fewer heads than Q. The number of heads
      in Q must be divisible by the number of heads in KV. For example, with 6
      query heads and 2 KV heads, query heads 0, 1, 2 attend to KV head 0 and
      query heads 3, 4, 5 attend to KV head 1.

      With causal=True the causal mask is aligned to the bottom-right corner of
      the attention matrix. For seqlen_q = 2 and seqlen_k = 5
      (1 = keep, 0 = masked out):
          1 1 1 1 0
          1 1 1 1 1
      For seqlen_q = 5 and seqlen_k = 2:
          0 0
          0 0
          0 0
          1 0
          1 1
      A query row whose mask is entirely zero produces a zero output.

      With window_size != (-1, -1), sliding-window local attention is used:
      query i attends only to keys in
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      (inclusive).

      Arguments:
          q: (total_q, nheads, headdim), where total_q = total number of query
              tokens in the batch.
          kv: (total_k, 2, nheads_k, headdim), where total_k = total number of
              key tokens in the batch.
          cu_seqlens_q: (batch_size + 1,), dtype torch.int32. Cumulative query
              sequence lengths, used to index into q.
          cu_seqlens_k: (batch_size + 1,), dtype torch.int32. Cumulative key
              sequence lengths, used to index into kv.
          max_seqlen_q: int. Maximum query sequence length in the batch.
          max_seqlen_k: int. Maximum key sequence length in the batch.
          dropout_p: float. Dropout probability.
          softmax_scale: float. Scaling applied to QK^T before the softmax.
              Defaults to 1 / sqrt(headdim).
          causal: bool. Whether to apply a causal mask (e.g. for auto-regressive
              modeling).
          window_size: (left, right). If not (-1, -1), enables sliding-window
              local attention.
          softcap: float. Any value > 0 enables softcapping.
          alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
              (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the
              attention score of query i and key j.
          deterministic: bool. Use the deterministic backward implementation,
              which is slightly slower and uses more memory. The forward pass is
              always deterministic.
          return_attn_probs: bool. Also return the attention probabilities.
              Testing only: the returned probabilities are not guaranteed to be
              correct (they may not have the right scaling).

      Return:
          out: (total, nheads, headdim).
          softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen).
              Logsumexp of each row of QK^T * scaling (i.e. the log of the softmax
              normalization factor).
          S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
              Softmax output (possibly with different scaling). Also encodes the
              dropout pattern: negative = dropped, nonnegative = kept.
      """
      # Group the inputs the way the autograd Function consumes them (positionally).
      tensors = (q, kv, cu_seqlens_q, cu_seqlens_k)
      lengths = (max_seqlen_q, max_seqlen_k)
      options = (
          dropout_p,
          softmax_scale,
          causal,
          window_size,
          softcap,
          alibi_slopes,
          deterministic,
          return_attn_probs,
      )
      return FlashAttnVarlenKVPackedFunc.apply(*tensors, *lengths, *options)
  def flash_attn_varlen_func(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      dropout_p=0.0,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),  # -1 means infinite context window
      softcap=0.0,  # 0.0 means deactivated
      alibi_slopes=None,
      deterministic=False,
      return_attn_probs=False,
      block_table=None,
  ):
      """Variable-length attention over separate Q, K and V tensors.

      Set dropout_p to 0.0 during evaluation.

      MQA/GQA is supported by giving K and V fewer heads than Q. The number of
      heads in Q must be divisible by the number of heads in K/V. For example,
      with 6 query heads and 2 KV heads, query heads 0, 1, 2 attend to KV head 0
      and query heads 3, 4, 5 attend to KV head 1.

      With causal=True the causal mask is aligned to the bottom-right corner of
      the attention matrix. For seqlen_q = 2 and seqlen_k = 5
      (1 = keep, 0 = masked out):
          1 1 1 1 0
          1 1 1 1 1
      For seqlen_q = 5 and seqlen_k = 2:
          0 0
          0 0
          0 0
          1 0
          1 1
      A query row whose mask is entirely zero produces a zero output.

      With window_size != (-1, -1), sliding-window local attention is used:
      query i attends only to keys in
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      (inclusive).

      Arguments:
          q: (total_q, nheads, headdim), where total_q = total number of query
              tokens in the batch.
          k: (total_k, nheads_k, headdim), where total_k = total number of key
              tokens in the batch.
          v: (total_k, nheads_k, headdim), where total_k = total number of key
              tokens in the batch.
          cu_seqlens_q: (batch_size + 1,), dtype torch.int32. Cumulative query
              sequence lengths, used to index into q.
          cu_seqlens_k: (batch_size + 1,), dtype torch.int32. Cumulative key
              sequence lengths, used to index into k and v.
          max_seqlen_q: int. Maximum query sequence length in the batch.
          max_seqlen_k: int. Maximum key sequence length in the batch.
          dropout_p: float. Dropout probability.
          softmax_scale: float. Scaling applied to QK^T before the softmax.
              Defaults to 1 / sqrt(headdim).
          causal: bool. Whether to apply a causal mask (e.g. for auto-regressive
              modeling).
          window_size: (left, right). If not (-1, -1), enables sliding-window
              local attention.
          softcap: float. Any value > 0 enables softcapping.
          alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
              (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the
              attention score of query i and key j.
          deterministic: bool. Use the deterministic backward implementation,
              which is slightly slower and uses more memory. The forward pass is
              always deterministic.
          return_attn_probs: bool. Also return the attention probabilities.
              Testing only: the returned probabilities are not guaranteed to be
              correct (they may not have the right scaling).
          block_table [optional]: forwarded unchanged to FlashAttnVarlenFunc.
              NOTE(review): presumably a paged-KV block table of shape
              (batch_size, max_num_blocks_per_seq), dtype torch.int32, as in
              flash_attn_with_kvcache. Confirm against FlashAttnVarlenFunc.

      Return:
          out: (total, nheads, headdim).
          softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen).
              Logsumexp of each row of QK^T * scaling (i.e. the log of the softmax
              normalization factor).
          S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
              Softmax output (possibly with different scaling). Also encodes the
              dropout pattern: negative = dropped, nonnegative = kept.
      """
      # FlashAttnVarlenFunc.apply only accepts positional arguments; the order
      # below must match its forward() signature exactly.
      positional = [q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k]
      positional += [dropout_p, softmax_scale, causal, window_size, softcap]
      positional += [alibi_slopes, deterministic, return_attn_probs, block_table]
      return FlashAttnVarlenFunc.apply(*positional)
  def flash_attn_with_kvcache(
      q,
      k_cache,
      v_cache,
      k=None,
      v=None,
      rotary_cos=None,
      rotary_sin=None,
      cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
      cache_batch_idx: Optional[torch.Tensor] = None,
      cache_leftpad: Optional[torch.Tensor] = None,
      block_table: Optional[torch.Tensor] = None,
      softmax_scale=None,
      causal=False,
      window_size=(-1, -1),  # -1 means infinite context window
      softcap=0.0,  # 0.0 means deactivated
      rotary_interleaved=True,
      alibi_slopes=None,
      num_splits=0,
      return_softmax_lse=False,
  ):
      """Attention against a KV cache, optionally appending new keys/values in place.

      If k and v are not None, k_cache and v_cache are updated *in place* with
      the new values from k and v. This is useful for incremental decoding: pass
      in the cached keys/values from the previous step, update them with the new
      keys/values from the current step, and run attention over the updated
      cache, all in one kernel.

      If you pass in k / v, the cache must be large enough to hold the new
      values. For example, pre-allocate the KV cache with the max sequence length
      and use cache_seqlens to track the current length of each sequence.

      Rotary embedding is applied as well if rotary_cos and rotary_sin are passed
      in. The key @k is rotated by rotary_cos and rotary_sin at indices
      cache_seqlens, cache_seqlens + 1, etc. If causal or local
      (i.e. window_size != (-1, -1)), the query @q is rotated at indices
      cache_seqlens, cache_seqlens + 1, etc. If neither causal nor local, @q is
      rotated at index cache_seqlens only (i.e. all tokens in @q are treated as
      being at position cache_seqlens).

      See tests/test_flash_attn.py::test_flash_attn_kvcache for usage examples.

      MQA/GQA is supported by giving KV fewer heads than Q. The number of heads
      in Q must be divisible by the number of heads in KV. For example, with 6
      query heads and 2 KV heads, query heads 0, 1, 2 attend to KV head 0 and
      query heads 3, 4, 5 attend to KV head 1.

      With causal=True the causal mask is aligned to the bottom-right corner of
      the attention matrix. For seqlen_q = 2 and seqlen_k = 5
      (1 = keep, 0 = masked out):
          1 1 1 1 0
          1 1 1 1 1
      For seqlen_q = 5 and seqlen_k = 2:
          0 0
          0 0
          0 0
          1 0
          1 1
      A query row whose mask is entirely zero produces a zero output.

      With window_size != (-1, -1), sliding-window local attention is used:
      query i attends only to keys in
      [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]
      (inclusive).

      Note: Does not support backward pass.

      Arguments:
          q: (batch_size, seqlen, nheads, headdim)
          k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's
              no block_table, or (num_blocks, page_block_size, nheads_k, headdim)
              if there's a block_table (i.e. paged KV cache).
              page_block_size must be a multiple of 256.
          v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's
              no block_table, or (num_blocks, page_block_size, nheads_k, headdim)
              if there's a block_table (i.e. paged KV cache).
          k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None,
              k is concatenated with k_cache starting at the indices given by
              cache_seqlens.
          v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
          rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, rotary
              embedding is applied to k and q. Only applicable if k and v are
              passed in. rotary_dim must be divisible by 16.
          rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
          cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence
              lengths of the KV cache. An int is broadcast to a (batch_size,)
              int32 tensor, with batch_size taken from q.shape[0].
          cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to
              index into the KV cache. If None, the batch indices are assumed to
              be [0, 1, 2, ..., batch_size - 1]. If the indices are not distinct
              and k and v are provided, the values updated in the cache may come
              from any of the duplicate indices.
          cache_leftpad: (batch_size,), dtype torch.int32. The index where the KV
              cache starts. If None, 0 is assumed.
          block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
          softmax_scale: float. Scaling applied to QK^T before the softmax.
              Defaults to 1 / sqrt(headdim).
          causal: bool. Whether to apply a causal mask (e.g. for auto-regressive
              modeling).
          window_size: (left, right). If not (-1, -1), enables sliding-window
              local attention.
          softcap: float. Any value > 0 enables softcapping.
          rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin
              are passed in. If True, rotary embedding combines dimensions 0 & 1,
              2 & 3, etc. If False, it combines dimensions 0 & rotary_dim / 2,
              1 & rotary_dim / 2 + 1 (i.e. GPT-NeoX style).
          alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
              (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the
              attention score of query i and key j.
          num_splits: int. If > 1, split the key/value into this many chunks
              along the sequence. If 1, do not split. If 0, use a heuristic to
              pick the number of splits. Don't change this unless you know what
              you are doing.
          return_softmax_lse: bool. Whether to also return the logsumexp of the
              attention scores.

      Return:
          out: (batch_size, seqlen, nheads, headdim).
          softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen).
              Logsumexp of each row of QK^T * scaling (i.e. the log of the softmax
              normalization factor).
      """
      assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
      assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
      q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
      if softmax_scale is None:
          softmax_scale = q.shape[-1] ** (-0.5)
      if isinstance(cache_seqlens, int):
          # Broadcast a scalar length to one entry per *query* batch element.
          # The documented shape is (batch_size,), and batch_size is q.shape[0].
          # k_cache.shape[0] is not batch_size in general: it is batch_size_cache
          # when cache_batch_idx remaps rows, and num_blocks when block_table
          # (paged KV cache) is used.
          cache_seqlens = torch.full(
              (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
          )
      # Normalize the layout of every optional per-batch index tensor the same
      # way. cache_leftpad was previously the only one skipped.
      cache_seqlens = maybe_contiguous(cache_seqlens)
      cache_batch_idx = maybe_contiguous(cache_batch_idx)
      cache_leftpad = maybe_contiguous(cache_leftpad)
      block_table = maybe_contiguous(block_table)
      out, softmax_lse = flash_attn_cuda.fwd_kvcache(
          q,
          k_cache,
          v_cache,
          k,
          v,
          cache_seqlens,
          rotary_cos,
          rotary_sin,
          cache_batch_idx,
          cache_leftpad,
          block_table,
          alibi_slopes,
          None,  # positional placeholder; its meaning is defined by fwd_kvcache (not visible here)
          softmax_scale,
          causal,
          window_size[0],
          window_size[1],
          softcap,
          rotary_interleaved,
          num_splits,
      )
      return (out, softmax_lse) if return_softmax_lse else out