1
0

flash_attn_interface.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(q, k, v, softmax_scale, causal):
  12. # q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  13. out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
  14. q,
  15. k,
  16. v,
  17. None,
  18. softmax_scale,
  19. causal,
  20. )
  21. return out, q, k, v, out_padded, softmax_lse, S_dmask
  22. def _flash_attn_backward(
  23. dout,
  24. q,
  25. k,
  26. v,
  27. out,
  28. softmax_lse,
  29. dq,
  30. dk,
  31. dv,
  32. softmax_scale,
  33. causal
  34. ):
  35. # dq, dk, dv are allocated by us so they should already be contiguous
  36. #dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  37. dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(
  38. dout,
  39. q,
  40. k,
  41. v,
  42. out,
  43. softmax_lse,
  44. dq,
  45. dk,
  46. dv,
  47. softmax_scale,
  48. causal,
  49. )
  50. return dq, dk, dv, softmax_d
  51. def _flash_attn_varlen_forward(
  52. q,
  53. k,
  54. v,
  55. cu_seqlens_q,
  56. cu_seqlens_k,
  57. max_seqlen_q,
  58. max_seqlen_k,
  59. softmax_scale,
  60. causal,
  61. ):
  62. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  63. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  64. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
  65. q,
  66. k,
  67. v,
  68. None,
  69. cu_seqlens_q,
  70. cu_seqlens_k,
  71. None,
  72. max_seqlen_q,
  73. max_seqlen_k,
  74. softmax_scale,
  75. causal,
  76. )
  77. # if out.isnan().any() or softmax_lse.isnan().any():
  78. # breakpoint()
  79. return out, q, k, v, out_padded, softmax_lse
  80. def _flash_attn_varlen_backward(
  81. dout,
  82. q,
  83. k,
  84. v,
  85. out,
  86. softmax_lse,
  87. dq,
  88. dk,
  89. dv,
  90. cu_seqlens_q,
  91. cu_seqlens_k,
  92. max_seqlen_q,
  93. max_seqlen_k,
  94. softmax_scale,
  95. causal,
  96. ):
  97. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  98. # dq, dk, dv are allocated by us so they should already be contiguous
  99. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  100. (
  101. dq,
  102. dk,
  103. dv,
  104. softmax_d,
  105. ) = _get_fa_module().varlen_bwd(
  106. dout,
  107. q,
  108. k,
  109. v,
  110. out,
  111. softmax_lse,
  112. dq,
  113. dk,
  114. dv,
  115. cu_seqlens_q,
  116. cu_seqlens_k,
  117. max_seqlen_q,
  118. max_seqlen_k,
  119. softmax_scale,
  120. causal,
  121. )
  122. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  123. # breakpoint()
  124. return dq, dk, dv, softmax_d
  125. class FlashAttnFunc(torch.autograd.Function):
  126. @staticmethod
  127. def forward(
  128. ctx,
  129. q,
  130. k,
  131. v,
  132. softmax_scale,
  133. causal,
  134. ):
  135. if softmax_scale is None:
  136. softmax_scale = q.shape[-1] ** (-0.5)
  137. out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
  138. q,
  139. k,
  140. v,
  141. softmax_scale,
  142. causal
  143. )
  144. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  145. ctx.softmax_scale = softmax_scale
  146. ctx.causal = causal
  147. return out, softmax_lse
  148. @staticmethod
  149. def backward(ctx, dout, *args):
  150. q, k, v, out, softmax_lse = ctx.saved_tensors
  151. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  152. _flash_attn_backward(
  153. dout,
  154. q,
  155. k,
  156. v,
  157. out,
  158. softmax_lse,
  159. dq,
  160. dk,
  161. dv,
  162. ctx.softmax_scale,
  163. ctx.causal,
  164. )
  165. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  166. dk = dk[..., : dout.shape[-1]]
  167. dv = dv[..., : dout.shape[-1]]
  168. return dq, dk, dv, None, None
  169. class FlashAttnVarlenFunc(torch.autograd.Function):
  170. @staticmethod
  171. def forward(
  172. ctx,
  173. q,
  174. k,
  175. v,
  176. cu_seqlens_q,
  177. cu_seqlens_k,
  178. max_seqlen_q,
  179. max_seqlen_k,
  180. softmax_scale,
  181. causal,
  182. ):
  183. if softmax_scale is None:
  184. softmax_scale = q.shape[-1] ** (-0.5)
  185. out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  186. q,
  187. k,
  188. v,
  189. cu_seqlens_q,
  190. cu_seqlens_k,
  191. max_seqlen_q,
  192. max_seqlen_k,
  193. softmax_scale,
  194. causal=causal,
  195. )
  196. ctx.save_for_backward(
  197. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
  198. )
  199. ctx.max_seqlen_q = max_seqlen_q
  200. ctx.max_seqlen_k = max_seqlen_k
  201. ctx.softmax_scale = softmax_scale
  202. ctx.causal = causal
  203. return out, softmax_lse
  204. @staticmethod
  205. def backward(ctx, dout, *args):
  206. # TODO: Uncomment these when var-seq-len is supported in bwd kernel.
  207. # q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
  208. # dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  209. # _flash_attn_varlen_backward(
  210. # dout,
  211. # q,
  212. # k,
  213. # v,
  214. # out,
  215. # softmax_lse,
  216. # dq,
  217. # dk,
  218. # dv,
  219. # cu_seqlens_q,
  220. # cu_seqlens_k,
  221. # ctx.max_seqlen_q,
  222. # ctx.max_seqlen_k,
  223. # ctx.softmax_scale,
  224. # ctx.causal,
  225. # )
  226. # dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  227. # dk = dk[..., : dout.shape[-1]]
  228. # dv = dv[..., : dout.shape[-1]]
  229. # return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
  230. return None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
  231. def flash_attn_func(
  232. q,
  233. k,
  234. v,
  235. softmax_scale=None,
  236. causal=False,
  237. ):
  238. """dropout_p should be set to 0.0 during evaluation
  239. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  240. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  241. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  242. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  243. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  244. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  245. 1 1 1 1 0
  246. 1 1 1 1 1
  247. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  248. 0 0
  249. 0 0
  250. 0 0
  251. 1 0
  252. 1 1
  253. If the row of the mask is all zero, the output will be zero.
  254. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  255. will only attend to keys between
  256. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  257. Arguments:
  258. q: (batch_size, seqlen, nheads, headdim)
  259. k: (batch_size, seqlen, nheads_k, headdim)
  260. v: (batch_size, seqlen, nheads_k, headdim)
  261. dropout_p: float. Dropout probability.
  262. softmax_scale: float. The scaling of QK^T before applying softmax.
  263. Default to 1 / sqrt(headdim).
  264. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  265. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  266. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  267. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  268. is added to the attention score of query i and key j.
  269. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  270. which is slightly slower and uses more memory. The forward pass is always deterministic.
  271. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  272. testing only. The returned probabilities are not guaranteed to be correct
  273. (they might not have the right scaling).
  274. Return:
  275. out: (batch_size, seqlen, nheads, headdim).
  276. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  277. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  278. normalization factor).
  279. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  280. The output of softmax (possibly with different scaling). It also encodes the dropout
  281. pattern (negative means that location was dropped, nonnegative means it was kept).
  282. """
  283. return FlashAttnFunc.apply(
  284. q,
  285. k,
  286. v,
  287. softmax_scale,
  288. causal,
  289. )
  290. def flash_attn_varlen_func(
  291. q,
  292. k,
  293. v,
  294. cu_seqlens_q,
  295. cu_seqlens_k,
  296. max_seqlen_q,
  297. max_seqlen_k,
  298. softmax_scale=None,
  299. causal=False,
  300. ):
  301. """
  302. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  303. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  304. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  305. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  306. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  307. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  308. 1 1 1 1 0
  309. 1 1 1 1 1
  310. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  311. 0 0
  312. 0 0
  313. 0 0
  314. 1 0
  315. 1 1
  316. If the row of the mask is all zero, the output will be zero.
  317. Arguments:
  318. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  319. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  320. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  321. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  322. of the sequences in the batch, used to index into q.
  323. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  324. of the sequences in the batch, used to index into kv.
  325. max_seqlen_q: int. Maximum query sequence length in the batch.
  326. max_seqlen_k: int. Maximum key sequence length in the batch.
  327. softmax_scale: float. The scaling of QK^T before applying softmax.
  328. Default to 1 / sqrt(headdim).
  329. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  330. Return:
  331. out: (total, nheads, headdim).
  332. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  333. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  334. normalization factor).
  335. """
  336. return FlashAttnVarlenFunc.apply(
  337. q,
  338. k,
  339. v,
  340. cu_seqlens_q,
  341. cu_seqlens_k,
  342. max_seqlen_q,
  343. max_seqlen_k,
  344. softmax_scale,
  345. causal,
  346. )