flash_attn_interface.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None):
  12. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  13. out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
  14. q,
  15. k,
  16. v,
  17. None,
  18. softmax_scale,
  19. descale_q,
  20. descale_k,
  21. descale_v,
  22. causal,
  23. window_size[0],
  24. window_size[1],
  25. )
  26. return out, q, k, v, out_padded, softmax_lse, S_dmask
  27. def _flash_attn_backward(
  28. dout,
  29. q,
  30. k,
  31. v,
  32. out,
  33. softmax_lse,
  34. dq,
  35. dk,
  36. dv,
  37. softmax_scale,
  38. causal,
  39. window_size,
  40. deterministic=False
  41. ):
  42. # dq, dk, dv are allocated by us so they should already be contiguous
  43. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  44. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
  45. dout,
  46. q,
  47. k,
  48. v,
  49. out,
  50. softmax_lse,
  51. dq,
  52. dk,
  53. dv,
  54. softmax_scale,
  55. causal,
  56. window_size[0],
  57. window_size[1],
  58. deterministic,
  59. )
  60. return dq, dk, dv, softmax_d
  61. def _flash_attn_varlen_forward(
  62. q,
  63. k,
  64. v,
  65. cu_seqlens_q,
  66. cu_seqlens_k,
  67. max_seqlen_q,
  68. max_seqlen_k,
  69. softmax_scale,
  70. causal,
  71. window_size=(-1, -1),
  72. seqused_q=None,
  73. seqused_k=None,
  74. optimize_for_doc_masking=False,
  75. ):
  76. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  77. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  78. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
  79. q,
  80. k,
  81. v,
  82. None,
  83. cu_seqlens_q,
  84. cu_seqlens_k,
  85. seqused_q,
  86. seqused_k,
  87. max_seqlen_q,
  88. max_seqlen_k,
  89. softmax_scale,
  90. causal,
  91. window_size[0],
  92. window_size[1],
  93. optimize_for_doc_masking,
  94. )
  95. # if out.isnan().any() or softmax_lse.isnan().any():
  96. # breakpoint()
  97. return out, q, k, v, out_padded, softmax_lse
  98. def _flash_attn_varlen_backward(
  99. dout,
  100. q,
  101. k,
  102. v,
  103. out,
  104. softmax_lse,
  105. dq,
  106. dk,
  107. dv,
  108. cu_seqlens_q,
  109. cu_seqlens_k,
  110. max_seqlen_q,
  111. max_seqlen_k,
  112. softmax_scale,
  113. causal,
  114. window_size,
  115. deterministic=False,
  116. seqused_q=None,
  117. seqused_k=None,
  118. optimize_for_doc_masking=False,
  119. ):
  120. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  121. # dq, dk, dv are allocated by us so they should already be contiguous
  122. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  123. (
  124. dq,
  125. dk,
  126. dv,
  127. softmax_d,
  128. *rest,
  129. ) = flashattn_hopper_cuda.varlen_bwd(
  130. dout,
  131. q,
  132. k,
  133. v,
  134. out,
  135. softmax_lse,
  136. dq,
  137. dk,
  138. dv,
  139. cu_seqlens_q,
  140. cu_seqlens_k,
  141. seqused_q,
  142. seqused_k,
  143. max_seqlen_q,
  144. max_seqlen_k,
  145. softmax_scale,
  146. causal,
  147. window_size[0],
  148. window_size[1],
  149. deterministic,
  150. optimize_for_doc_masking,
  151. )
  152. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  153. # breakpoint()
  154. return dq, dk, dv, softmax_d
  155. class FlashAttnFunc(torch.autograd.Function):
  156. @staticmethod
  157. def forward(
  158. ctx,
  159. q,
  160. k,
  161. v,
  162. softmax_scale,
  163. causal,
  164. window_size,
  165. deterministic=False,
  166. descale_q=None,
  167. descale_k=None,
  168. descale_v=None,
  169. ):
  170. if softmax_scale is None:
  171. softmax_scale = q.shape[-1] ** (-0.5)
  172. out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
  173. q,
  174. k,
  175. v,
  176. softmax_scale,
  177. causal,
  178. window_size,
  179. descale_q=descale_q,
  180. descale_k=descale_k,
  181. descale_v=descale_v,
  182. )
  183. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  184. ctx.softmax_scale = softmax_scale
  185. ctx.causal = causal
  186. ctx.window_size = window_size
  187. ctx.deterministic = deterministic
  188. return out, softmax_lse
  189. @staticmethod
  190. def backward(ctx, dout, *args):
  191. q, k, v, out, softmax_lse = ctx.saved_tensors
  192. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  193. _flash_attn_backward(
  194. dout,
  195. q,
  196. k,
  197. v,
  198. out,
  199. softmax_lse,
  200. dq,
  201. dk,
  202. dv,
  203. ctx.softmax_scale,
  204. ctx.causal,
  205. ctx.window_size,
  206. ctx.deterministic,
  207. )
  208. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  209. dk = dk[..., : dout.shape[-1]]
  210. dv = dv[..., : dout.shape[-1]]
  211. return dq, dk, dv, None, None, None, None, None, None, None
  212. class FlashAttnVarlenFunc(torch.autograd.Function):
  213. @staticmethod
  214. def forward(
  215. ctx,
  216. q,
  217. k,
  218. v,
  219. cu_seqlens_q,
  220. cu_seqlens_k,
  221. max_seqlen_q,
  222. max_seqlen_k,
  223. softmax_scale,
  224. causal,
  225. window_size,
  226. deterministic=False,
  227. seqused_q=None,
  228. seqused_k=None,
  229. optimize_for_doc_masking=False,
  230. ):
  231. if softmax_scale is None:
  232. softmax_scale = q.shape[-1] ** (-0.5)
  233. out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  234. q,
  235. k,
  236. v,
  237. cu_seqlens_q,
  238. cu_seqlens_k,
  239. max_seqlen_q,
  240. max_seqlen_k,
  241. softmax_scale,
  242. causal=causal,
  243. window_size=window_size,
  244. seqused_q=seqused_q,
  245. seqused_k=seqused_k,
  246. optimize_for_doc_masking=optimize_for_doc_masking,
  247. )
  248. ctx.save_for_backward(
  249. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
  250. seqused_q, seqused_k
  251. )
  252. ctx.max_seqlen_q = max_seqlen_q
  253. ctx.max_seqlen_k = max_seqlen_k
  254. ctx.softmax_scale = softmax_scale
  255. ctx.causal = causal
  256. ctx.window_size = window_size
  257. ctx.deterministic = deterministic
  258. ctx.optimize_for_doc_masking = optimize_for_doc_masking
  259. return out, softmax_lse
  260. @staticmethod
  261. def backward(ctx, dout, *args):
  262. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
  263. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  264. _flash_attn_varlen_backward(
  265. dout,
  266. q,
  267. k,
  268. v,
  269. out,
  270. softmax_lse,
  271. dq,
  272. dk,
  273. dv,
  274. cu_seqlens_q,
  275. cu_seqlens_k,
  276. ctx.max_seqlen_q,
  277. ctx.max_seqlen_k,
  278. ctx.softmax_scale,
  279. ctx.causal,
  280. ctx.window_size,
  281. ctx.deterministic,
  282. seqused_q,
  283. seqused_k,
  284. ctx.optimize_for_doc_masking,
  285. )
  286. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  287. dk = dk[..., : dout.shape[-1]]
  288. dv = dv[..., : dout.shape[-1]]
  289. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
  290. def flash_attn_func(
  291. q,
  292. k,
  293. v,
  294. softmax_scale=None,
  295. causal=False,
  296. window_size=(-1, -1),
  297. deterministic=False,
  298. descale_q=None,
  299. descale_k=None,
  300. descale_v=None,
  301. ):
  302. """dropout_p should be set to 0.0 during evaluation
  303. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  304. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  305. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  306. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  307. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  308. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  309. 1 1 1 1 0
  310. 1 1 1 1 1
  311. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  312. 0 0
  313. 0 0
  314. 0 0
  315. 1 0
  316. 1 1
  317. If the row of the mask is all zero, the output will be zero.
  318. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  319. will only attend to keys between
  320. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  321. Arguments:
  322. q: (batch_size, seqlen, nheads, headdim)
  323. k: (batch_size, seqlen, nheads_k, headdim)
  324. v: (batch_size, seqlen, nheads_k, headdim)
  325. dropout_p: float. Dropout probability.
  326. softmax_scale: float. The scaling of QK^T before applying softmax.
  327. Default to 1 / sqrt(headdim).
  328. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  329. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  330. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  331. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  332. is added to the attention score of query i and key j.
  333. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  334. which is slightly slower and uses more memory. The forward pass is always deterministic.
  335. descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
  336. descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
  337. descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
  338. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  339. testing only. The returned probabilities are not guaranteed to be correct
  340. (they might not have the right scaling).
  341. Return:
  342. out: (batch_size, seqlen, nheads, headdim).
  343. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  344. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  345. normalization factor).
  346. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  347. The output of softmax (possibly with different scaling). It also encodes the dropout
  348. pattern (negative means that location was dropped, nonnegative means it was kept).
  349. """
  350. return FlashAttnFunc.apply(
  351. q,
  352. k,
  353. v,
  354. softmax_scale,
  355. causal,
  356. window_size,
  357. deterministic,
  358. descale_q,
  359. descale_k,
  360. descale_v,
  361. )
  362. def flash_attn_varlen_func(
  363. q,
  364. k,
  365. v,
  366. cu_seqlens_q,
  367. cu_seqlens_k,
  368. max_seqlen_q,
  369. max_seqlen_k,
  370. softmax_scale=None,
  371. causal=False,
  372. window_size=(-1, -1),
  373. deterministic=False,
  374. seqused_q=None,
  375. seqused_k=None,
  376. optimize_for_doc_masking=False,
  377. ):
  378. """
  379. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  380. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  381. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  382. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  383. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  384. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  385. 1 1 1 1 0
  386. 1 1 1 1 1
  387. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  388. 0 0
  389. 0 0
  390. 0 0
  391. 1 0
  392. 1 1
  393. If the row of the mask is all zero, the output will be zero.
  394. Arguments:
  395. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  396. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  397. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  398. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  399. of the sequences in the batch, used to index into q.
  400. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  401. of the sequences in the batch, used to index into kv.
  402. max_seqlen_q: int. Maximum query sequence length in the batch.
  403. max_seqlen_k: int. Maximum key sequence length in the batch.
  404. softmax_scale: float. The scaling of QK^T before applying softmax.
  405. Default to 1 / sqrt(headdim).
  406. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  407. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  408. seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
  409. query and output tokens in each sequence.
  410. seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
  411. key and value tokens in each sequence.
  412. Return:
  413. out: (total, nheads, headdim).
  414. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  415. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  416. normalization factor).
  417. """
  418. return FlashAttnVarlenFunc.apply(
  419. q,
  420. k,
  421. v,
  422. cu_seqlens_q,
  423. cu_seqlens_k,
  424. max_seqlen_q,
  425. max_seqlen_k,
  426. softmax_scale,
  427. causal,
  428. window_size,
  429. deterministic,
  430. seqused_q,
  431. seqused_k,
  432. optimize_for_doc_masking,
  433. )