# flash_attn_interface.py
from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda

# isort: on


def _get_block_size(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128, 128
    if head_dim <= 64:
        return (128, 128) if not is_dropout else (128, 64)
    elif head_dim <= 96:
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
    elif head_dim <= 128:
        if is_sm8x:
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
        else:
            return 128, (64 if not is_dropout else 32)
    elif head_dim <= 160:
        if is_sm8x:
            return (128, 64) if not is_causal else (64, 64)
        else:
            return 128, 32
    elif head_dim <= 192:
        return (128, 64) if not is_dropout else (64, 64)
    elif head_dim <= 224:
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
    elif head_dim <= 256:
        return (128, 64) if is_sm80 else (64, 64)
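

# Illustrative sketch (not part of the original file): how the block-size heuristic above
# could be queried on the current CUDA device. The head dims chosen here are arbitrary examples.
def _example_block_sizes():
    if not torch.cuda.is_available():
        return
    device = torch.cuda.current_device()
    for head_dim in (64, 128, 256):
        # Returns the (block_M, block_N) tile sizes the CUDA kernel uses for this configuration.
        blocks = _get_block_size(device, head_dim, is_dropout=False, is_causal=True)
        print(f"head_dim={head_dim}: block sizes {blocks}")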


def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        None,
        rng_state,
    )
    return dq, dk, dv, softmax_d
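

# Illustrative sketch (not part of the original file): how the autograd classes below use the
# private forward/backward wrappers, with gradient buffers pre-allocated by the caller.
# Assumes a CUDA device and fp16/bf16 inputs; dropout is disabled here for simplicity, and
# `dout` is assumed to have the same shape as `out`.
def _example_private_fwd_bwd(q, k, v, dout, softmax_scale, causal):
    out, q, k, v, out_padded, softmax_lse, _, rng_state = _flash_attn_forward(
        q, k, v, dropout_p=0.0, softmax_scale=softmax_scale, causal=causal, return_softmax=False
    )
    # Gradient buffers are allocated up front and filled in place by the kernel.
    dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
    _flash_attn_backward(
        dout, q, k, v, out_padded, softmax_lse, dq, dk, dv, 0.0, softmax_scale, causal,
        rng_state=rng_state,
    )
    return out, dq, dk, dv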


def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        None,
        rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None
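

# Small CPU-only sketch (not part of the original file) of why the packed layout helps the
# backward pass above: the dQ, dK, dV slices are views into one allocation, so no explicit
# concatenation of the gradients is needed afterwards. Shapes are arbitrary examples.
def _example_packed_grad_views():
    dqkv = torch.empty(2, 16, 3, 8, 64)
    dq, dk, dv = dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2]
    # The first slice starts at the same address as the packed buffer...
    assert dq.data_ptr() == dqkv.data_ptr()
    # ...and the last slice still lies within that same buffer.
    assert dv.data_ptr() < dqkv.data_ptr() + dqkv.numel() * dqkv.element_size()
    return dqkv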


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            kv[:, 0],
            kv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient per forward input: q, k, v, dropout_p, softmax_scale, causal, return_softmax
        return dq, dk, dv, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
    qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """dropout_p should be set to 0.0 during evaluation

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
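

# Illustrative usage sketch (not part of the original file). Assumes a CUDA device;
# FlashAttention expects fp16 or bf16 inputs. The shapes below are arbitrary examples.
def _example_qkvpacked():
    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    qkv = torch.randn(
        batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True
    )
    out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
    # out: (batch, seqlen, nheads, headdim); gradients flow back into the single packed qkv tensor.
    out.sum().backward()
    return out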


def flash_attn_kvpacked_func(
    q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """dropout_p should be set to 0.0 during evaluation

    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
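

# Illustrative MQA/GQA sketch (not part of the original file). K/V have fewer heads than Q;
# nheads must be divisible by nheads_k, as the docstring above requires. Assumes a CUDA device;
# the shapes below are arbitrary examples.
def _example_kvpacked_gqa():
    batch, seqlen, nheads, nheads_k, headdim = 2, 512, 8, 2, 64
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    kv = torch.randn(batch, seqlen, 2, nheads_k, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_kvpacked_func(q, kv, causal=True)
    return out  # (batch, seqlen, nheads, headdim)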


def flash_attn_func(
    q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """dropout_p should be set to 0.0 during evaluation

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
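

# Illustrative drop-in sketch (not part of the original file). Many attention modules keep
# tensors as (batch, nheads, seqlen, headdim), whereas flash_attn_func expects
# (batch, seqlen, nheads, headdim), so transpose on the way in and out. Assumes a CUDA device
# and fp16/bf16 inputs.
def _example_flash_attn_from_bhsd(q, k, v):
    out = flash_attn_func(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True
    )
    return out.transpose(1, 2)  # back to (batch, nheads, seqlen, headdim)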


def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
    )
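

# Illustrative sketch (not part of the original file): building cu_seqlens as the exclusive
# prefix sum of per-sequence lengths and calling the varlen packed interface. Assumes a CUDA
# device and fp16 inputs; the lengths below are arbitrary examples.
def _example_varlen_qkvpacked():
    seqlens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
    # cu_seqlens = [0, 5, 8, 15]: token i of sequence b lives at row cu_seqlens[b] + i of qkv.
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    total, nheads, headdim = int(cu_seqlens[-1]), 8, 64
    qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_varlen_qkvpacked_func(
        qkv, cu_seqlens, max_seqlen=int(seqlens.max()), causal=True
    )
    return out  # (total, nheads, headdim)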


def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation

    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of K, V.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
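

# Illustrative sketch (not part of the original file): variable-length MQA/GQA with separate
# query and key cumulative lengths. Assumes a CUDA device; the lengths and shapes are
# arbitrary examples (per-sequence q lengths 4 and 6, k lengths 6 and 8).
def _example_varlen_kvpacked():
    cu_seqlens_q = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, 6, 14], dtype=torch.int32, device="cuda")
    nheads, nheads_k, headdim = 8, 2, 64
    q = torch.randn(10, nheads, headdim, device="cuda", dtype=torch.float16)
    kv = torch.randn(14, 2, nheads_k, headdim, device="cuda", dtype=torch.float16)
    out = flash_attn_varlen_kvpacked_func(
        q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q=6, max_seqlen_k=8, causal=False
    )
    return out  # (total_q, nheads, headdim)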


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
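

# Illustrative sketch (not part of the original file): running varlen attention over a padded
# batch by first gathering the non-padding tokens. `attention_mask` is an assumed
# (batch, seqlen) bool tensor; q, k, v are assumed to be (batch, seqlen, nheads, headdim)
# fp16/bf16 CUDA tensors with the same mask for queries and keys.
def _example_varlen_from_padded(q, k, v, attention_mask):
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=q.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    max_seqlen = int(seqlens.max())
    # Flatten the batch and keep only the valid (non-padding) token positions.
    idx = attention_mask.flatten().nonzero(as_tuple=True)[0]
    unpad = lambda t: t.reshape(-1, *t.shape[2:])[idx]  # (total, nheads, headdim)
    out = flash_attn_varlen_func(
        unpad(q), unpad(k), unpad(v), cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True
    )
    return out  # (total, nheads, headdim); scatter back with `idx` to re-pad if needed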


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    softmax_scale=None,
    causal=False,
    rotary_interleaved=True,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Rotary embedding is also applied if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens,
    cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin
    at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        None,
        softmax_scale,
        causal,
        rotary_interleaved,
        num_splits,
    )
    return out
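

# Illustrative incremental-decoding sketch (not part of the original file). The cache is
# pre-allocated to a maximum length and cache_seqlens tracks how much of it is filled; shapes
# and the cache length are assumptions. Assumes a CUDA device and fp16/bf16 tensors.
def _example_decode_step(q, k_new, v_new, k_cache, v_cache, cache_seqlens):
    # q, k_new, v_new: (batch, 1, nheads[_k], headdim) for one decoding step.
    # k_cache, v_cache: (batch, max_seqlen, nheads_k, headdim), updated in place by the kernel.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
    )
    # The new key/value have been written into the cache, so advance the per-sequence lengths.
    return out, cache_seqlens + 1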