flash_attn_interface.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  def maybe_contiguous(x):
      """Return ``x`` with a unit stride on its last dimension.

      ``None`` is passed through unchanged. A tensor whose last-dimension
      stride is already 1 is returned as-is (no copy); any other tensor is
      replaced by ``x.contiguous()``.
      """
      if x is None or x.stride(-1) == 1:
          return x
      return x.contiguous()
  def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q=None, descale_k=None, descale_v=None):
      """Call the Hopper fixed-length forward kernel.

      q, k and v are first given a unit stride on their last dimension via
      ``maybe_contiguous``. The descale factors are forwarded unchanged.

      Returns:
          The 7-tuple produced by ``flashattn_hopper_cuda.fwd``:
          (out, q, k, v, out_padded, softmax_lse, S_dmask). The q, k, v in
          this tuple are the tensors handed back by the kernel.
      """
      q_in, k_in, v_in = map(maybe_contiguous, (q, k, v))
      (out, q_out, k_out, v_out, out_padded, softmax_lse, S_dmask) = flashattn_hopper_cuda.fwd(
          q_in,
          k_in,
          v_in,
          None,  # NOTE(review): presumably the optional pre-allocated output slot; confirm in the extension
          softmax_scale,
          descale_q,
          descale_k,
          descale_v,
          causal,
      )
      return out, q_out, k_out, v_out, out_padded, softmax_lse, S_dmask
  def _flash_attn_backward(
      dout,
      q,
      k,
      v,
      out,
      softmax_lse,
      dq,
      dk,
      dv,
      softmax_scale,
      causal,
      deterministic=False
  ):
      """Call the Hopper fixed-length backward kernel.

      dout, q, k, v and out are given a unit last-dimension stride first. The
      gradient buffers dq, dk, dv are handed to the kernel untouched.

      Returns:
          (dq, dk, dv, softmax_d): the first four values returned by
          ``flashattn_hopper_cuda.bwd``; any further values are discarded.
      """
      # The caller allocates dq/dk/dv itself, so only the inputs need re-striding.
      g_out, q_in, k_in, v_in, o_in = map(maybe_contiguous, (dout, q, k, v, out))
      grad_q, grad_k, grad_v, softmax_d, *_unused = flashattn_hopper_cuda.bwd(
          g_out,
          q_in,
          k_in,
          v_in,
          o_in,
          softmax_lse,
          dq,
          dk,
          dv,
          softmax_scale,
          causal,
          deterministic,
      )
      return grad_q, grad_k, grad_v, softmax_d
  def _flash_attn_varlen_forward(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale,
      causal,
      seqused_q=None,
      seqused_k=None,
  ):
      """Call the Hopper variable-length forward kernel.

      Args:
          q: packed queries (expected (total_q, nheads, headdim)).
          k, v: packed keys/values (expected (total_k, nheads_k, headdim)).
          cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths used to index into q and k/v.
          max_seqlen_q, max_seqlen_k: maximum query/key sequence length in the batch.
          softmax_scale: scaling applied to QK^T before the softmax.
          causal: whether to apply the causal mask.
          seqused_q, seqused_k: optional per-sequence counts of tokens actually used.

      Returns:
          (out, q, k, v, out_padded, softmax_lse) exactly as returned by
          ``flashattn_hopper_cuda.varlen_fwd``.
      """
      # Use the shared module-level helper (None-safe) instead of a local lambda.
      q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
      out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
          q,
          k,
          v,
          None,
          cu_seqlens_q,
          cu_seqlens_k,
          seqused_q,
          seqused_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
      )
      return out, q, k, v, out_padded, softmax_lse
  def _flash_attn_varlen_backward(
      dout,
      q,
      k,
      v,
      out,
      softmax_lse,
      dq,
      dk,
      dv,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale,
      causal,
      deterministic=False,
      seqused_q=None,
      seqused_k=None,
  ):
      """Call the Hopper variable-length backward kernel.

      Args:
          dout: gradient of the attention output.
          q, k, v, out, softmax_lse: tensors saved from the forward pass.
          dq, dk, dv: gradient buffers, passed to the kernel without re-striding.
          cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, seqused_q, seqused_k:
              the same variable-length batch description used in the forward pass.
          softmax_scale: scaling applied to QK^T.
          causal: whether the causal mask is applied.
          deterministic: forwarded to the kernel to select the deterministic backward.

      Returns:
          (dq, dk, dv, softmax_d): the first four values returned by
          ``flashattn_hopper_cuda.varlen_bwd``; any further values are discarded.
      """
      # dq, dk, dv are allocated by the caller so they should already be contiguous;
      # only the inputs go through the shared module-level helper.
      dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
      dq, dk, dv, softmax_d, *_ = flashattn_hopper_cuda.varlen_bwd(
          dout,
          q,
          k,
          v,
          out,
          softmax_lse,
          dq,
          dk,
          dv,
          cu_seqlens_q,
          cu_seqlens_k,
          seqused_q,
          seqused_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          deterministic,
      )
      return dq, dk, dv, softmax_d
  class FlashAttnFunc(torch.autograd.Function):
      """Autograd wrapper around the fixed-length Hopper attention kernels.

      ``forward`` returns ``(out, softmax_lse)``. ``backward`` produces gradients
      for q, k and v only; the incoming gradient for softmax_lse is ignored.
      """

      @staticmethod
      def forward(
          ctx,
          q,
          k,
          v,
          softmax_scale,
          causal,
          deterministic=False,
          descale_q=None,
          descale_k=None,
          descale_v=None,
      ):
          # Fall back to the conventional 1/sqrt(headdim) scaling.
          scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
          fwd_results = _flash_attn_forward(
              q,
              k,
              v,
              scale,
              causal,
              descale_q=descale_q,
              descale_k=descale_k,
              descale_v=descale_v,
          )
          out, q, k, v, out_padded, softmax_lse, _s_dmask = fwd_results
          # The padded output is what the backward kernel consumes as `out`.
          ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
          ctx.softmax_scale, ctx.causal, ctx.deterministic = scale, causal, deterministic
          return out, softmax_lse

      @staticmethod
      def backward(ctx, dout, *args):
          q, k, v, out, softmax_lse = ctx.saved_tensors
          grad_q, grad_k, grad_v = (torch.empty_like(t) for t in (q, k, v))
          _flash_attn_backward(
              dout,
              q,
              k,
              v,
              out,
              softmax_lse,
              grad_q,
              grad_k,
              grad_v,
              ctx.softmax_scale,
              ctx.causal,
              ctx.deterministic,
          )
          # Trim to the incoming gradient's head dimension in case it was padded.
          head_dim = dout.shape[-1]
          grad_q, grad_k, grad_v = (g[..., :head_dim] for g in (grad_q, grad_k, grad_v))
          # No gradients for softmax_scale, causal, deterministic or the three descale factors.
          return (grad_q, grad_k, grad_v) + (None,) * 6
  class FlashAttnVarlenFunc(torch.autograd.Function):
      """Autograd wrapper around the variable-length Hopper attention kernels.

      ``forward`` returns ``(out, softmax_lse)``. ``backward`` produces gradients
      for q, k and v only; the incoming gradient for softmax_lse is ignored.
      """

      @staticmethod
      def forward(
          ctx,
          q,
          k,
          v,
          cu_seqlens_q,
          cu_seqlens_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          deterministic=False,
          seqused_q=None,
          seqused_k=None,
      ):
          # Fall back to the conventional 1/sqrt(headdim) scaling.
          scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** (-0.5)
          out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
              q,
              k,
              v,
              cu_seqlens_q,
              cu_seqlens_k,
              max_seqlen_q,
              max_seqlen_k,
              scale,
              causal=causal,
              seqused_q=seqused_q,
              seqused_k=seqused_k,
          )
          # seqused_q / seqused_k may be None; save_for_backward stores them as such.
          ctx.save_for_backward(
              q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k
          )
          ctx.max_seqlen_q, ctx.max_seqlen_k = max_seqlen_q, max_seqlen_k
          ctx.softmax_scale, ctx.causal, ctx.deterministic = scale, causal, deterministic
          return out, softmax_lse

      @staticmethod
      def backward(ctx, dout, *args):
          (
              q,
              k,
              v,
              out,
              softmax_lse,
              cu_seqlens_q,
              cu_seqlens_k,
              seqused_q,
              seqused_k,
          ) = ctx.saved_tensors
          grad_q, grad_k, grad_v = (torch.empty_like(t) for t in (q, k, v))
          _flash_attn_varlen_backward(
              dout,
              q,
              k,
              v,
              out,
              softmax_lse,
              grad_q,
              grad_k,
              grad_v,
              cu_seqlens_q,
              cu_seqlens_k,
              ctx.max_seqlen_q,
              ctx.max_seqlen_k,
              ctx.softmax_scale,
              ctx.causal,
              ctx.deterministic,
              seqused_q=seqused_q,
              seqused_k=seqused_k,
          )
          # Trim to the incoming gradient's head dimension in case it was padded.
          head_dim = dout.shape[-1]
          grad_q, grad_k, grad_v = (g[..., :head_dim] for g in (grad_q, grad_k, grad_v))
          # No gradients for the nine non-q/k/v forward arguments.
          return (grad_q, grad_k, grad_v) + (None,) * 9
  def flash_attn_func(
      q,
      k,
      v,
      softmax_scale=None,
      causal=False,
      deterministic=False,
      descale_q=None,
      descale_k=None,
      descale_v=None,
  ):
      """Compute attention for fixed-length batches with the Hopper FlashAttention kernels.

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
      than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
      For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
      0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
      For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If a row of the mask is all zero, the output for that row will be zero.

      Arguments:
          q: (batch_size, seqlen_q, nheads, headdim)
          k: (batch_size, seqlen_k, nheads_k, headdim)
          v: (batch_size, seqlen_k, nheads_k, headdim)
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Defaults to 1 / sqrt(headdim).
          causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
          deterministic: bool. Whether to use the deterministic implementation of the backward pass,
              which is slightly slower and uses more memory. The forward pass is always deterministic.
          descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
          descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
          descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.

      Return:
          A tuple ``(out, softmax_lse)``; both are always returned.
          out: (batch_size, seqlen_q, nheads, headdim).
          softmax_lse: (batch_size, nheads, seqlen_q). The logsumexp of each row of the matrix
              QK^T * scaling (e.g., log of the softmax normalization factor).
      """
      # torch.autograd.Function.apply takes positional arguments only.
      return FlashAttnFunc.apply(
          q,
          k,
          v,
          softmax_scale,
          causal,
          deterministic,
          descale_q,
          descale_k,
          descale_v,
      )
  def flash_attn_varlen_func(
      q,
      k,
      v,
      cu_seqlens_q,
      cu_seqlens_k,
      max_seqlen_q,
      max_seqlen_k,
      softmax_scale=None,
      causal=False,
      deterministic=False,
      seqused_q=None,
      seqused_k=None,
  ):
      """Compute attention over a packed batch of variable-length sequences (Hopper kernels).

      Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
      than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
      For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend to head
      0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

      If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
      For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
          1 1 1 1 0
          1 1 1 1 1
      If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
          0 0
          0 0
          0 0
          1 0
          1 1
      If a row of the mask is all zero, the output for that row will be zero.

      Arguments:
          q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
          k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
          v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
          cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
              of the sequences in the batch, used to index into q.
          cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
              of the sequences in the batch, used to index into k and v.
          max_seqlen_q: int. Maximum query sequence length in the batch.
          max_seqlen_k: int. Maximum key sequence length in the batch.
          softmax_scale: float. The scaling of QK^T before applying softmax.
              Defaults to 1 / sqrt(headdim).
          causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
          deterministic: bool. Whether to use the deterministic implementation of the backward pass.
          seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
              query and output tokens in each sequence.
          seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
              key and value tokens in each sequence.

      Return:
          A tuple ``(out, softmax_lse)``; both are always returned.
          out: (total_q, nheads, headdim).
          softmax_lse: (nheads, total_q). The logsumexp of each row of the matrix
              QK^T * scaling (e.g., log of the softmax normalization factor).
      """
      # torch.autograd.Function.apply takes positional arguments only.
      return FlashAttnVarlenFunc.apply(
          q,
          k,
          v,
          cu_seqlens_q,
          cu_seqlens_k,
          max_seqlen_q,
          max_seqlen_k,
          softmax_scale,
          causal,
          deterministic,
          seqused_q,
          seqused_k,
      )