flash_attn_interface.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  # Shared helper: normalise the last-dimension stride before calling the CUDA kernels.
def maybe_contiguous(x):
    """Return ``x`` with stride 1 in its last dimension.

    ``None``, and any tensor whose last dimension already has stride 1, is
    returned unchanged (the same object, no copy). Any other tensor is replaced
    by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
  # Dense (fixed-length) forward wrapper around the Hopper CUDA extension.
def _flash_attn_forward(
    q, k, v, softmax_scale, causal, descale_q=None, descale_k=None, descale_v=None
):
    """Call ``flashattn_hopper_cuda.fwd`` on dense q/k/v tensors.

    q, k and v are first passed through ``maybe_contiguous`` so that their last
    dimension has unit stride. The ``descale_*`` arguments are forwarded to the
    kernel unchanged. NOTE(review): they are presumably per-tensor de-scaling
    factors for low-precision inputs; confirm against the extension.

    Returns:
        The 7-tuple ``(out, q, k, v, out_padded, softmax_lse, S_dmask)``
        unpacked from the kernel's result.
    """
    q = maybe_contiguous(q)
    k = maybe_contiguous(k)
    v = maybe_contiguous(v)
    kernel_outputs = flashattn_hopper_cuda.fwd(
        q,
        k,
        v,
        None,  # NOTE(review): optional kernel argument left unset; confirm its meaning
        softmax_scale,
        descale_q,
        descale_k,
        descale_v,
        causal,
    )
    out, q, k, v, out_padded, softmax_lse, S_dmask = kernel_outputs
    return out, q, k, v, out_padded, softmax_lse, S_dmask
  # Dense (fixed-length) backward wrapper around the Hopper CUDA extension.
def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, softmax_scale, causal, deterministic=False
):
    """Call ``flashattn_hopper_cuda.bwd`` for dense attention.

    ``dq``, ``dk`` and ``dv`` are caller-provided buffers that are handed to the
    kernel. ``dout``, ``q``, ``k``, ``v`` and ``out`` are made last-dim
    contiguous first.

    Returns:
        ``(dq, dk, dv, softmax_d)``: the first four outputs of the kernel. Any
        further outputs are discarded.
    """
    # The gradient buffers are allocated by this module's callers, so they
    # should already be contiguous; only the remaining inputs are normalised.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    dq, dk, dv, softmax_d, *_ = flashattn_hopper_cuda.bwd(
        dout, q, k, v, out, softmax_lse,
        dq, dk, dv,
        softmax_scale, causal, deterministic,
    )
    return dq, dk, dv, softmax_d
  # Variable-length (packed sequences) forward wrapper around the Hopper CUDA extension.
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
):
    """Call ``flashattn_hopper_cuda.varlen_fwd`` on packed q/k/v tensors.

    q, k and v are made last-dim contiguous with the module-level
    ``maybe_contiguous`` helper. ``cu_seqlens_*`` and ``max_seqlen_*`` are
    forwarded to the kernel unchanged.

    Returns:
        The 6-tuple ``(out, q, k, v, out_padded, softmax_lse)`` unpacked from the
        kernel's result.
    """
    # Use the shared module helper. The previous local lambda shadowed it,
    # dropped its None handling and violated PEP 8 (E731).
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # NOTE(review): optional kernel argument left unset; confirm its meaning
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # NOTE(review): optional kernel argument left unset; confirm its meaning
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
    )
    return out, q, k, v, out_padded, softmax_lse
  # Variable-length (packed sequences) backward wrapper around the Hopper CUDA extension.
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
    deterministic=False,
):
    """Call ``flashattn_hopper_cuda.varlen_bwd`` for packed variable-length attention.

    ``dq``, ``dk`` and ``dv`` are caller-provided buffers that are handed to the
    kernel. ``dout``, ``q``, ``k``, ``v`` and ``out`` are made last-dim
    contiguous with the module-level ``maybe_contiguous`` helper.

    Returns:
        ``(dq, dk, dv, softmax_d)``: the first four outputs of the kernel. Any
        further outputs are discarded.
    """
    # dq, dk, dv are allocated by this module's callers, so they should already
    # be contiguous. The shared helper replaces the former shadowing lambda.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, *_ = flashattn_hopper_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        deterministic,
    )
    return dq, dk, dv, softmax_d
  # Autograd binding for dense (fixed-length) attention.
class FlashAttnFunc(torch.autograd.Function):
    """``torch.autograd.Function`` wrapping the dense forward/backward kernels.

    ``forward`` returns ``(out, softmax_lse)``. ``backward`` returns gradients
    for ``q``, ``k`` and ``v``, and ``None`` for ``softmax_scale``, ``causal``
    and ``deterministic``.
    """

    @staticmethod
    def forward(ctx, q, k, v, softmax_scale, causal, deterministic=False):
        # Default scale is 1 / sqrt(headdim), taken from the incoming q.
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, _ = _flash_attn_forward(
            q, k, v, scale, causal
        )
        # Keep the tensors returned by the kernel (not the caller's originals) for backward.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse,
            dq, dk, dv,
            ctx.softmax_scale, ctx.causal, ctx.deterministic,
        )
        # The head dimension could have been padded; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dq, dk, dv = (grad[..., :head_dim] for grad in (dq, dk, dv))
        return dq, dk, dv, None, None, None
  # Autograd binding for packed variable-length attention.
class FlashAttnVarlenFunc(torch.autograd.Function):
    """``torch.autograd.Function`` wrapping the variable-length kernels.

    ``forward`` returns ``(out, softmax_lse)``. ``backward`` returns gradients
    for ``q``, ``k`` and ``v``, and ``None`` for the seven non-tensor or index
    inputs (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    softmax_scale, causal, deterministic).
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        deterministic=False,
    ):
        # Default scale is 1 / sqrt(headdim), taken from the incoming q.
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k,
            scale,
            causal=causal,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
        )
        # Integer and bool settings are stored as plain ctx attributes.
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout, q, k, v, out, softmax_lse,
            dq, dk, dv,
            cu_seqlens_q, cu_seqlens_k,
            ctx.max_seqlen_q, ctx.max_seqlen_k,
            ctx.softmax_scale, ctx.causal, ctx.deterministic,
        )
        # The head dimension could have been padded; trim back to dout's head size.
        head_dim = dout.shape[-1]
        dq, dk, dv = (grad[..., :head_dim] for grad in (dq, dk, dv))
        return (dq, dk, dv) + (None,) * 7
  # Public entry point: dense (fixed-length) attention.
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    deterministic=False
):
    """Compute attention on dense, batched q/k/v with the FlashAttention kernels.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    Arguments:
        q: (batch_size, seqlen_q, nheads, headdim)
        k: (batch_size, seqlen_k, nheads_k, headdim)
        v: (batch_size, seqlen_k, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.

    Return:
        A tuple ``(out, softmax_lse)`` (both are always returned):
        out: (batch_size, seqlen_q, nheads, headdim).
        softmax_lse: (batch_size, nheads, seqlen_q). The logsumexp of each row of the matrix
            QK^T * scaling (e.g., log of the softmax normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        deterministic,
    )
  # Public entry point: packed variable-length attention.
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale=None,
    causal=False,
    deterministic=False,
):
    """Compute attention on packed variable-length sequences with the FlashAttention kernels.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        deterministic: bool. Forwarded to the backward kernel; presumably selects a deterministic
            (but slower) backward pass. TODO confirm against the extension.

    Return:
        A tuple ``(out, softmax_lse)`` (both are always returned):
        out: (total_q, nheads, headdim).
        softmax_lse: (nheads, total_q). The logsumexp of each row of the matrix
            QK^T * scaling (e.g., log of the softmax normalization factor).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        deterministic,
    )