# flash_attn_interface.py
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flash_attn_2_cuda as flash_attn_cuda
  8. # isort: on
  9. def _get_block_size_n(device, head_dim, is_dropout, is_causal):
  10. # This should match the block sizes in the CUDA kernel
  11. assert head_dim <= 256
  12. major, minor = torch.cuda.get_device_capability(device)
  13. is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
  14. is_sm80 = major == 8 and minor == 0
  15. is_sm90 = major == 9 and minor == 0
  16. if head_dim <= 32:
  17. return 128
  18. if head_dim <= 64:
  19. return 128 if not is_dropout else 64
  20. elif head_dim <= 96:
  21. return 64
  22. elif head_dim <= 128:
  23. if is_sm8x:
  24. return 64 if (not is_dropout and is_causal) else 32
  25. else:
  26. return 64 if not is_dropout else 32
  27. elif head_dim <= 160:
  28. if is_sm8x:
  29. return 64
  30. else:
  31. return 32
  32. elif head_dim <= 192:
  33. return 64
  34. elif head_dim <= 224:
  35. return 64
  36. elif head_dim <= 256:
  37. return 64
  38. def _flash_attn_forward(
  39. q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
  40. ):
  41. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  42. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  43. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
  44. q,
  45. k,
  46. v,
  47. None,
  48. alibi_slopes,
  49. dropout_p,
  50. softmax_scale,
  51. causal,
  52. window_size[0],
  53. window_size[1],
  54. softcap,
  55. return_softmax,
  56. None,
  57. )
  58. return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
  59. def _flash_attn_varlen_forward(
  60. q,
  61. k,
  62. v,
  63. cu_seqlens_q,
  64. cu_seqlens_k,
  65. max_seqlen_q,
  66. max_seqlen_k,
  67. dropout_p,
  68. softmax_scale,
  69. causal,
  70. window_size=(-1, -1),
  71. softcap=0.0,
  72. alibi_slopes=None,
  73. return_softmax=False,
  74. block_table=None,
  75. leftpad_k=None,
  76. seqused_k=None,
  77. ):
  78. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  79. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  80. out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
  81. q,
  82. k,
  83. v,
  84. None,
  85. cu_seqlens_q,
  86. cu_seqlens_k,
  87. seqused_k,
  88. leftpad_k,
  89. block_table,
  90. alibi_slopes,
  91. max_seqlen_q,
  92. max_seqlen_k,
  93. dropout_p,
  94. softmax_scale,
  95. False,
  96. causal,
  97. window_size[0],
  98. window_size[1],
  99. softcap,
  100. return_softmax,
  101. None,
  102. )
  103. # if out.isnan().any() or softmax_lse.isnan().any():
  104. # breakpoint()
  105. return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
  106. def _flash_attn_backward(
  107. dout,
  108. q,
  109. k,
  110. v,
  111. out,
  112. softmax_lse,
  113. dq,
  114. dk,
  115. dv,
  116. dropout_p,
  117. softmax_scale,
  118. causal,
  119. window_size,
  120. softcap,
  121. alibi_slopes,
  122. deterministic,
  123. rng_state=None,
  124. ):
  125. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  126. # dq, dk, dv are allocated by us so they should already be contiguous
  127. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  128. (
  129. dq,
  130. dk,
  131. dv,
  132. softmax_d,
  133. ) = flash_attn_cuda.bwd(
  134. dout,
  135. q,
  136. k,
  137. v,
  138. out,
  139. softmax_lse,
  140. dq,
  141. dk,
  142. dv,
  143. alibi_slopes,
  144. dropout_p,
  145. softmax_scale,
  146. causal,
  147. window_size[0],
  148. window_size[1],
  149. softcap,
  150. deterministic,
  151. None,
  152. rng_state,
  153. )
  154. return dq, dk, dv, softmax_d
  155. def _flash_attn_varlen_backward(
  156. dout,
  157. q,
  158. k,
  159. v,
  160. out,
  161. softmax_lse,
  162. dq,
  163. dk,
  164. dv,
  165. cu_seqlens_q,
  166. cu_seqlens_k,
  167. max_seqlen_q,
  168. max_seqlen_k,
  169. dropout_p,
  170. softmax_scale,
  171. causal,
  172. window_size,
  173. softcap,
  174. alibi_slopes,
  175. deterministic,
  176. rng_state=None,
  177. ):
  178. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  179. # dq, dk, dv are allocated by us so they should already be contiguous
  180. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  181. (
  182. dq,
  183. dk,
  184. dv,
  185. softmax_d,
  186. ) = flash_attn_cuda.varlen_bwd(
  187. dout,
  188. q,
  189. k,
  190. v,
  191. out,
  192. softmax_lse,
  193. dq,
  194. dk,
  195. dv,
  196. cu_seqlens_q,
  197. cu_seqlens_k,
  198. alibi_slopes,
  199. max_seqlen_q,
  200. max_seqlen_k,
  201. dropout_p,
  202. softmax_scale,
  203. False,
  204. causal,
  205. window_size[0],
  206. window_size[1],
  207. softcap,
  208. deterministic,
  209. None,
  210. rng_state,
  211. )
  212. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  213. # breakpoint()
  214. return dq, dk, dv, softmax_d
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention with Q, K, V packed into one tensor.

    ``qkv`` carries Q/K/V along dim 2 (``qkv[:, :, 0]`` is Q), so the
    backward pass can write all three gradients into a single allocation
    instead of concatenating separate dq/dk/dv tensors.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default softmax scaling: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested when dropout is actually active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save out_padded (not out): the kernel may have padded the head dim.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs when return_softmax was set.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # One buffer shaped like the packed qkv; the kernel writes dq/dk/dv
        # directly into its slices along dim 2.
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient slot per forward() input; non-tensor args get None.
        return dqkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length attention with packed Q, K, V.

    ``qkv`` is (total_tokens, 3, nheads, headdim); sequence boundaries come
    from ``cu_seqlens``. The same ``cu_seqlens``/``max_seqlen`` is used for
    both queries and keys since they originate from the same packed tensor.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default softmax scaling: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested when dropout is actually active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        # Save out_padded (not out): the kernel may have padded the head dim.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs when return_softmax was set.
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # One buffer shaped like the packed qkv; the kernel writes dq/dk/dv
        # directly into its slices along dim 1.
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient slot per forward() input; non-tensor args get None.
        return dqkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention with K and V packed into one tensor.

    ``kv`` carries K/V along dim 2 (``kv[:, :, 0]`` is K), so the backward
    pass can write both gradients into a single allocation.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default softmax scaling: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested when dropout is actually active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save out_padded (not out): the kernel may have padded the head dim.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs when return_softmax was set.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # One buffer shaped like the packed kv; the kernel writes dk/dv
        # directly into its slices along dim 2.
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One gradient slot per forward() input; non-tensor args get None.
        return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length attention with packed K, V.

    ``kv`` is (total_tokens_k, 2, nheads_k, headdim); sequence boundaries
    come from ``cu_seqlens_q`` / ``cu_seqlens_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default softmax scaling: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested when dropout is actually active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        # Save out_padded (not out): the kernel may have padded the head dim.
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs when return_softmax was set.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # One buffer shaped like the packed kv; the kernel writes dk/dv
        # directly into its slices along dim 1.
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        # One gradient slot per forward() input; non-tensor args get None.
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for attention with separate Q, K, V tensors
    (fixed-length, padded batch layout)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        # Default softmax scaling: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested when dropout is actually active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save out_padded (not out): the kernel may have padded the head dim.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs when return_softmax was set.
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient slot per forward() input; non-tensor args get None.
        return dq, dk, dv, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length attention with separate Q, K, V.

    Sequences are packed along the token dimension and delimited by
    ``cu_seqlens_q`` / ``cu_seqlens_k``. ``block_table`` enables paged-KV
    lookups in the forward pass (forward-only; it is not saved for backward).
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        # Default softmax scaling: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # S_dmask is only requested when dropout is actually active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        # Save out_padded (not out): the kernel may have padded the head dim.
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        # *args absorbs gradients for the extra outputs when return_softmax was set.
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient slot per forward() input; non-tensor args get None.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin public entry point: all work happens in the autograd Function.
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin public entry point: all work happens in the autograd Function.
    return FlashAttnKVPackedFunc.apply(
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # Thin public entry point: all work happens in the autograd Function.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
  849. def flash_attn_varlen_qkvpacked_func(
  850. qkv,
  851. cu_seqlens,
  852. max_seqlen,
  853. dropout_p=0.0,
  854. softmax_scale=None,
  855. causal=False,
  856. window_size=(-1, -1), # -1 means infinite context window
  857. softcap=0.0, # 0.0 means deactivated
  858. alibi_slopes=None,
  859. deterministic=False,
  860. return_attn_probs=False,
  861. ):
  862. """dropout_p should be set to 0.0 during evaluation
  863. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  864. calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
  865. of the gradients of Q, K, V.
  866. For multi-query and grouped-query attention (MQA/GQA), please see
  867. flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
  868. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  869. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  870. Arguments:
  871. qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
  872. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  873. of the sequences in the batch, used to index into qkv.
  874. max_seqlen: int. Maximum sequence length in the batch.
  875. dropout_p: float. Dropout probability.
  876. softmax_scale: float. The scaling of QK^T before applying softmax.
  877. Default to 1 / sqrt(headdim).
  878. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  879. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  880. softcap: float. Anything > 0 activates softcapping attention.
  881. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
  882. is added to the attention score of query i and key j.
  883. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  884. which is slightly slower and uses more memory. The forward pass is always deterministic.
  885. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  886. testing only. The returned probabilities are not guaranteed to be correct
  887. (they might not have the right scaling).
  888. Return:
  889. out: (total, nheads, headdim).
  890. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  891. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  892. normalization factor).
  893. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  894. The output of softmax (possibly with different scaling). It also encodes the dropout
  895. pattern (negative means that location was dropped, nonnegative means it was kept).
  896. """
  897. return FlashAttnVarlenQKVPackedFunc.apply(
  898. qkv,
  899. cu_seqlens,
  900. max_seqlen,
  901. dropout_p,
  902. softmax_scale,
  903. causal,
  904. window_size,
  905. softcap,
  906. alibi_slopes,
  907. deterministic,
  908. return_attn_probs,
  909. )
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
            NOTE(review): for the varlen path, seqlen here presumably refers to the padded
            per-batch sequence lengths — confirm against the kernel before relying on it.
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
            Block table for a paged KV cache (see flash_attn_with_kvcache). If None, K and V
            are assumed to be laid out contiguously per cu_seqlens_k.
    Return:
        out: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
  1084. def flash_attn_with_kvcache(
  1085. q,
  1086. k_cache,
  1087. v_cache,
  1088. k=None,
  1089. v=None,
  1090. rotary_cos=None,
  1091. rotary_sin=None,
  1092. cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
  1093. cache_batch_idx: Optional[torch.Tensor] = None,
  1094. cache_leftpad: Optional[torch.Tensor] = None,
  1095. block_table: Optional[torch.Tensor] = None,
  1096. softmax_scale=None,
  1097. causal=False,
  1098. window_size=(-1, -1), # -1 means infinite context window
  1099. softcap=0.0, # 0.0 means deactivated
  1100. rotary_interleaved=True,
  1101. alibi_slopes=None,
  1102. num_splits=0,
  1103. return_softmax_lse=False,
  1104. ):
  1105. """
  1106. If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
  1107. k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
  1108. the previous step, and update them with the new keys/values from the current step, and do
  1109. attention with the updated cache, all in 1 kernel.
  1110. If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
  1111. For example, the KV cache could be pre-allocated with the max sequence length, and you can use
  1112. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
  1113. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
  1114. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1115. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
  1116. and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
  1117. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
  1118. indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
  1119. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
  1120. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  1121. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  1122. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  1123. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  1124. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  1125. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  1126. 1 1 1 1 0
  1127. 1 1 1 1 1
  1128. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  1129. 0 0
  1130. 0 0
  1131. 0 0
  1132. 1 0
  1133. 1 1
  1134. If the row of the mask is all zero, the output will be zero.
  1135. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  1136. will only attend to keys between
  1137. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  1138. Note: Does not support backward pass.
  1139. Arguments:
  1140. q: (batch_size, seqlen, nheads, headdim)
  1141. k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1142. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1143. page_block_size must be a multiple of 256.
  1144. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
  1145. or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
  1146. k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
  1147. k with k_cache, starting at the indices specified by cache_seqlens.
  1148. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
  1149. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
  1150. to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
  1151. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
  1152. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
  1153. KV cache.
  1154. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
  1155. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
  1156. If the indices are not distinct, and k and v are provided, the values updated in the cache
  1157. might come from any of the duplicate indices.
  1158. cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
  1159. block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
  1160. softmax_scale: float. The scaling of QK^T before applying softmax.
  1161. Default to 1 / sqrt(headdim).
  1162. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  1163. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  1164. softcap: float. Anything > 0 activates softcapping attention.
  1165. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
  1166. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
  1167. rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
  1168. (i.e. GPT-NeoX style).
  1169. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  1170. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  1171. is added to the attention score of query i and key j.
  1172. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
  1173. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
  1174. to automatically determine the number of splits.
  1175. Don't change this unless you know what you are doing.
  1176. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
  1177. Return:
  1178. out: (batch_size, seqlen, nheads, headdim).
  1179. softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
  1180. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  1181. normalization factor).
  1182. """
  1183. assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
  1184. assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
  1185. maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
  1186. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  1187. if softmax_scale is None:
  1188. softmax_scale = q.shape[-1] ** (-0.5)
  1189. if cache_seqlens is not None and isinstance(cache_seqlens, int):
  1190. cache_seqlens = torch.full(
  1191. (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
  1192. )
  1193. cache_seqlens = maybe_contiguous(cache_seqlens)
  1194. cache_batch_idx = maybe_contiguous(cache_batch_idx)
  1195. block_table = maybe_contiguous(block_table)
  1196. out, softmax_lse = flash_attn_cuda.fwd_kvcache(
  1197. q,
  1198. k_cache,
  1199. v_cache,
  1200. k,
  1201. v,
  1202. cache_seqlens,
  1203. rotary_cos,
  1204. rotary_sin,
  1205. cache_batch_idx,
  1206. cache_leftpad,
  1207. block_table,
  1208. alibi_slopes,
  1209. None,
  1210. softmax_scale,
  1211. causal,
  1212. window_size[0],
  1213. window_size[1],
  1214. softcap,
  1215. rotary_interleaved,
  1216. num_splits,
  1217. )
  1218. return (out, softmax_lse) if return_softmax_lse else out