1
0

flash_attn_interface.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  9. def maybe_contiguous(x):
  10. return x.contiguous() if x is not None and x.stride(-1) != 1 else x
  11. def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descale_k = None, descale_v = None):
  12. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  13. out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
  14. q,
  15. k,
  16. v,
  17. None,
  18. softmax_scale,
  19. descale_q,
  20. descale_k,
  21. descale_v,
  22. causal,
  23. )
  24. return out, q, k, v, out_padded, softmax_lse, S_dmask
  25. def _flash_attn_backward(
  26. dout,
  27. q,
  28. k,
  29. v,
  30. out,
  31. softmax_lse,
  32. dq,
  33. dk,
  34. dv,
  35. softmax_scale,
  36. causal,
  37. deterministic=False
  38. ):
  39. # dq, dk, dv are allocated by us so they should already be contiguous
  40. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  41. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
  42. dout,
  43. q,
  44. k,
  45. v,
  46. out,
  47. softmax_lse,
  48. dq,
  49. dk,
  50. dv,
  51. softmax_scale,
  52. causal,
  53. deterministic,
  54. )
  55. return dq, dk, dv, softmax_d
  56. def _flash_attn_varlen_forward(
  57. q,
  58. k,
  59. v,
  60. cu_seqlens_q,
  61. cu_seqlens_k,
  62. max_seqlen_q,
  63. max_seqlen_k,
  64. softmax_scale,
  65. causal,
  66. seqused_q=None,
  67. seqused_k=None,
  68. ):
  69. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  70. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  71. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
  72. q,
  73. k,
  74. v,
  75. None,
  76. cu_seqlens_q,
  77. cu_seqlens_k,
  78. seqused_q,
  79. seqused_k,
  80. max_seqlen_q,
  81. max_seqlen_k,
  82. softmax_scale,
  83. causal,
  84. )
  85. # if out.isnan().any() or softmax_lse.isnan().any():
  86. # breakpoint()
  87. return out, q, k, v, out_padded, softmax_lse
  88. def _flash_attn_varlen_backward(
  89. dout,
  90. q,
  91. k,
  92. v,
  93. out,
  94. softmax_lse,
  95. dq,
  96. dk,
  97. dv,
  98. cu_seqlens_q,
  99. cu_seqlens_k,
  100. max_seqlen_q,
  101. max_seqlen_k,
  102. softmax_scale,
  103. causal,
  104. deterministic=False,
  105. seqused_q=None,
  106. seqused_k=None,
  107. ):
  108. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  109. # dq, dk, dv are allocated by us so they should already be contiguous
  110. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  111. (
  112. dq,
  113. dk,
  114. dv,
  115. softmax_d,
  116. *rest,
  117. ) = flashattn_hopper_cuda.varlen_bwd(
  118. dout,
  119. q,
  120. k,
  121. v,
  122. out,
  123. softmax_lse,
  124. dq,
  125. dk,
  126. dv,
  127. cu_seqlens_q,
  128. cu_seqlens_k,
  129. seqused_q,
  130. seqused_k,
  131. max_seqlen_q,
  132. max_seqlen_k,
  133. softmax_scale,
  134. causal,
  135. deterministic,
  136. )
  137. # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
  138. # breakpoint()
  139. return dq, dk, dv, softmax_d
  140. class FlashAttnFunc(torch.autograd.Function):
  141. @staticmethod
  142. def forward(
  143. ctx,
  144. q,
  145. k,
  146. v,
  147. softmax_scale,
  148. causal,
  149. deterministic=False,
  150. ):
  151. if softmax_scale is None:
  152. softmax_scale = q.shape[-1] ** (-0.5)
  153. out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
  154. q,
  155. k,
  156. v,
  157. softmax_scale,
  158. causal
  159. )
  160. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  161. ctx.softmax_scale = softmax_scale
  162. ctx.causal = causal
  163. ctx.deterministic = deterministic
  164. return out, softmax_lse
  165. @staticmethod
  166. def backward(ctx, dout, *args):
  167. q, k, v, out, softmax_lse = ctx.saved_tensors
  168. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  169. _flash_attn_backward(
  170. dout,
  171. q,
  172. k,
  173. v,
  174. out,
  175. softmax_lse,
  176. dq,
  177. dk,
  178. dv,
  179. ctx.softmax_scale,
  180. ctx.causal,
  181. ctx.deterministic,
  182. )
  183. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  184. dk = dk[..., : dout.shape[-1]]
  185. dv = dv[..., : dout.shape[-1]]
  186. return dq, dk, dv, None, None, None
  187. class FlashAttnVarlenFunc(torch.autograd.Function):
  188. @staticmethod
  189. def forward(
  190. ctx,
  191. q,
  192. k,
  193. v,
  194. cu_seqlens_q,
  195. cu_seqlens_k,
  196. max_seqlen_q,
  197. max_seqlen_k,
  198. softmax_scale,
  199. causal,
  200. deterministic=False,
  201. seqused_q=None,
  202. seqused_k=None,
  203. ):
  204. if softmax_scale is None:
  205. softmax_scale = q.shape[-1] ** (-0.5)
  206. out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  207. q,
  208. k,
  209. v,
  210. cu_seqlens_q,
  211. cu_seqlens_k,
  212. max_seqlen_q,
  213. max_seqlen_k,
  214. softmax_scale,
  215. causal=causal,
  216. seqused_q=seqused_q,
  217. seqused_k=seqused_k,
  218. )
  219. ctx.save_for_backward(
  220. q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
  221. seqused_q, seqused_k
  222. )
  223. ctx.max_seqlen_q = max_seqlen_q
  224. ctx.max_seqlen_k = max_seqlen_k
  225. ctx.softmax_scale = softmax_scale
  226. ctx.causal = causal
  227. ctx.deterministic = deterministic
  228. return out, softmax_lse
  229. @staticmethod
  230. def backward(ctx, dout, *args):
  231. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
  232. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  233. _flash_attn_varlen_backward(
  234. dout,
  235. q,
  236. k,
  237. v,
  238. out,
  239. softmax_lse,
  240. dq,
  241. dk,
  242. dv,
  243. cu_seqlens_q,
  244. cu_seqlens_k,
  245. ctx.max_seqlen_q,
  246. ctx.max_seqlen_k,
  247. ctx.softmax_scale,
  248. ctx.causal,
  249. ctx.deterministic,
  250. seqused_q,
  251. seqused_k,
  252. )
  253. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  254. dk = dk[..., : dout.shape[-1]]
  255. dv = dv[..., : dout.shape[-1]]
  256. return dq, dk, dv, None, None, None, None, None, None, None, None, None
  257. def flash_attn_func(
  258. q,
  259. k,
  260. v,
  261. softmax_scale=None,
  262. causal=False,
  263. deterministic=False
  264. ):
  265. """dropout_p should be set to 0.0 during evaluation
  266. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  267. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  268. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  269. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  270. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  271. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  272. 1 1 1 1 0
  273. 1 1 1 1 1
  274. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  275. 0 0
  276. 0 0
  277. 0 0
  278. 1 0
  279. 1 1
  280. If the row of the mask is all zero, the output will be zero.
  281. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  282. will only attend to keys between
  283. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  284. Arguments:
  285. q: (batch_size, seqlen, nheads, headdim)
  286. k: (batch_size, seqlen, nheads_k, headdim)
  287. v: (batch_size, seqlen, nheads_k, headdim)
  288. dropout_p: float. Dropout probability.
  289. softmax_scale: float. The scaling of QK^T before applying softmax.
  290. Default to 1 / sqrt(headdim).
  291. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  292. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  293. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  294. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  295. is added to the attention score of query i and key j.
  296. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  297. which is slightly slower and uses more memory. The forward pass is always deterministic.
  298. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  299. testing only. The returned probabilities are not guaranteed to be correct
  300. (they might not have the right scaling).
  301. Return:
  302. out: (batch_size, seqlen, nheads, headdim).
  303. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  304. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  305. normalization factor).
  306. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  307. The output of softmax (possibly with different scaling). It also encodes the dropout
  308. pattern (negative means that location was dropped, nonnegative means it was kept).
  309. """
  310. return FlashAttnFunc.apply(
  311. q,
  312. k,
  313. v,
  314. softmax_scale,
  315. causal,
  316. deterministic,
  317. )
  318. def flash_attn_varlen_func(
  319. q,
  320. k,
  321. v,
  322. cu_seqlens_q,
  323. cu_seqlens_k,
  324. max_seqlen_q,
  325. max_seqlen_k,
  326. softmax_scale=None,
  327. causal=False,
  328. deterministic=False,
  329. seqused_q=None,
  330. seqused_k=None,
  331. ):
  332. """
  333. Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
  334. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  335. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  336. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  337. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  338. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  339. 1 1 1 1 0
  340. 1 1 1 1 1
  341. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  342. 0 0
  343. 0 0
  344. 0 0
  345. 1 0
  346. 1 1
  347. If the row of the mask is all zero, the output will be zero.
  348. Arguments:
  349. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
  350. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  351. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
  352. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  353. of the sequences in the batch, used to index into q.
  354. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
  355. of the sequences in the batch, used to index into kv.
  356. max_seqlen_q: int. Maximum query sequence length in the batch.
  357. max_seqlen_k: int. Maximum key sequence length in the batch.
  358. softmax_scale: float. The scaling of QK^T before applying softmax.
  359. Default to 1 / sqrt(headdim).
  360. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  361. seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
  362. query and output tokens in each sequence.
  363. seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
  364. key and value tokens in each sequence.
  365. Return:
  366. out: (total, nheads, headdim).
  367. softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
  368. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  369. normalization factor).
  370. """
  371. return FlashAttnVarlenFunc.apply(
  372. q,
  373. k,
  374. v,
  375. cu_seqlens_q,
  376. cu_seqlens_k,
  377. max_seqlen_q,
  378. max_seqlen_k,
  379. softmax_scale,
  380. causal,
  381. deterministic,
  382. seqused_q,
  383. seqused_k,
  384. )