flash_attn_interface.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn as nn
  5. # isort: off
  6. # We need to import the CUDA kernels after importing torch
  7. import flashattn_hopper_cuda
  8. # isort: on
  9. def _flash_attn_forward(q, k, v, softmax_scale, causal,
  10. q_scale=None, k_scale=None, v_scale=None,
  11. window_size=(-1, -1),
  12. softcap=0.0):
  13. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  14. q, k = [maybe_contiguous(x) for x in (q, k)]
  15. v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
  16. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.fwd(
  17. q,
  18. k,
  19. v,
  20. None,
  21. softmax_scale,
  22. causal,
  23. q_scale, k_scale, v_scale,
  24. window_size[0], window_size[1],
  25. softcap
  26. )
  27. return out, q, k, v, out_padded, softmax_lse
  28. def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal,
  29. q_scale=None, k_scale=None, v_scale=None,
  30. window_size=(-1, -1), softcap=0.0):
  31. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  32. q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
  33. out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.fwd_varlen(
  34. q,
  35. k,
  36. v,
  37. None,
  38. cu_seqlens_q, cu_seqlens_k, None, None, max_seqlen_q, max_seqlen_k,
  39. softmax_scale,
  40. causal,
  41. q_scale, k_scale, v_scale,
  42. window_size[0], window_size[1],
  43. softcap,
  44. )
  45. return out, q, k, v, out_padded, softmax_lse
  46. def _flash_attn_backward(
  47. dout,
  48. q,
  49. k,
  50. v,
  51. out,
  52. softmax_lse,
  53. dq,
  54. dk,
  55. dv,
  56. softmax_scale,
  57. causal,
  58. window_size=(-1, -1),
  59. softcap=0.0,
  60. deterministic=False
  61. ):
  62. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  63. # dq, dk, dv are allocated by us so they should already be contiguous
  64. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  65. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
  66. dout,
  67. q,
  68. k,
  69. v,
  70. out,
  71. softmax_lse,
  72. dq,
  73. dk,
  74. dv,
  75. softmax_scale,
  76. causal,
  77. window_size[0],
  78. window_size[1],
  79. softcap,
  80. deterministic,
  81. )
  82. return dq, dk, dv, softmax_d
  83. def _flash_attn_varlen_backward(
  84. dout,
  85. q,
  86. k,
  87. v,
  88. out,
  89. softmax_lse,
  90. cu_seqlens_q,
  91. cu_seqlens_k,
  92. max_seqlen_q,
  93. max_seqlen_k,
  94. dq,
  95. dk,
  96. dv,
  97. softmax_scale,
  98. causal,
  99. window_size=(-1, -1),
  100. softcap=0.0,
  101. deterministic=False
  102. ):
  103. maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
  104. # dq, dk, dv are allocated by us so they should already be contiguous
  105. dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
  106. dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd_varlen(
  107. dout,
  108. q,
  109. k,
  110. v,
  111. out,
  112. softmax_lse,
  113. dq,
  114. dk,
  115. dv,
  116. cu_seqlens_q,
  117. cu_seqlens_k,
  118. None, None,
  119. max_seqlen_q,
  120. max_seqlen_k,
  121. softmax_scale,
  122. causal,
  123. window_size[0],
  124. window_size[1],
  125. softcap,
  126. deterministic,
  127. )
  128. return dq, dk, dv, softmax_d
  129. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  130. @staticmethod
  131. def forward(
  132. ctx,
  133. qkv,
  134. softmax_scale,
  135. causal,
  136. q_scale=None, k_scale=None, v_scale=None,
  137. window_size=(-1, -1),
  138. softcap=0.0,
  139. deterministic=False,
  140. num_heads_q=None,
  141. ):
  142. if softmax_scale is None:
  143. softmax_scale = qkv.shape[-1] ** (-0.5)
  144. if qkv.dim() == 5:
  145. assert qkv.shape[-3] == 3
  146. q, k, v = qkv.unbind(dim=-3)
  147. else:
  148. assert qkv.dim() == 4
  149. assert num_heads_q is not None
  150. num_heads_k = (qkv.shape[2] - num_heads_q) // 2
  151. assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
  152. q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  153. out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  154. q,
  155. k,
  156. v,
  157. softmax_scale,
  158. causal=causal,
  159. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  160. window_size=window_size,
  161. softcap=softcap,
  162. )
  163. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  164. ctx.softmax_scale = softmax_scale
  165. ctx.causal = causal
  166. ctx.window_size = window_size
  167. ctx.softcap = softcap
  168. ctx.deterministic = deterministic
  169. ctx.ndim = qkv.dim()
  170. # return out, softmax_lse
  171. return out
  172. @staticmethod
  173. def backward(ctx, dout, *args):
  174. q, k, v, out, softmax_lse = ctx.saved_tensors
  175. if ctx.ndim == 5:
  176. qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
  177. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  178. dq, dk, dv = dqkv.unbind(dim=-3)
  179. else:
  180. num_heads_q = q.shape[2]
  181. num_heads_k = k.shape[2]
  182. qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
  183. dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
  184. dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
  185. _flash_attn_backward(
  186. dout,
  187. q,
  188. k,
  189. v,
  190. out,
  191. softmax_lse,
  192. dq,
  193. dk,
  194. dv,
  195. ctx.softmax_scale,
  196. ctx.causal,
  197. ctx.window_size,
  198. ctx.softcap,
  199. ctx.deterministic,
  200. )
  201. dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
  202. return dqkv, None, None, None, None, None, None, None, None, None
  203. class FlashAttnFunc(torch.autograd.Function):
  204. @staticmethod
  205. def forward(
  206. ctx,
  207. q,
  208. k,
  209. v,
  210. softmax_scale,
  211. causal,
  212. q_scale=None, k_scale=None, v_scale=None,
  213. window_size=(-1, -1),
  214. softcap=0.0,
  215. deterministic=False,
  216. ):
  217. if softmax_scale is None:
  218. softmax_scale = q.shape[-1] ** (-0.5)
  219. out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
  220. q,
  221. k,
  222. v,
  223. softmax_scale,
  224. causal=causal,
  225. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  226. window_size=window_size,
  227. softcap=softcap,
  228. )
  229. ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
  230. ctx.softmax_scale = softmax_scale
  231. ctx.causal = causal
  232. ctx.window_size = window_size
  233. ctx.softcap = softcap
  234. ctx.deterministic = deterministic
  235. return out, softmax_lse
  236. @staticmethod
  237. def backward(ctx, dout, *args):
  238. q, k, v, out, softmax_lse = ctx.saved_tensors
  239. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  240. _flash_attn_backward(
  241. dout,
  242. q,
  243. k,
  244. v,
  245. out,
  246. softmax_lse,
  247. dq,
  248. dk,
  249. dv,
  250. ctx.softmax_scale,
  251. ctx.causal,
  252. ctx.window_size,
  253. ctx.softcap,
  254. ctx.deterministic,
  255. )
  256. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  257. dk = dk[..., : dout.shape[-1]]
  258. dv = dv[..., : dout.shape[-1]]
  259. return dq, dk, dv, None, None, None, None, None, None, None, None
  260. class FlashAttnVarlenFunc(torch.autograd.Function):
  261. @staticmethod
  262. def forward(
  263. ctx,
  264. q,
  265. k,
  266. v,
  267. cu_seqlens_q,
  268. cu_seqlens_k,
  269. max_seqlen_q,
  270. max_seqlen_k,
  271. softmax_scale,
  272. causal,
  273. q_scale=None, k_scale=None, v_scale=None,
  274. window_size=(-1, -1),
  275. softcap=0.0,
  276. deterministic=False,
  277. ):
  278. if softmax_scale is None:
  279. softmax_scale = q.shape[-1] ** (-0.5)
  280. out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
  281. q,
  282. k,
  283. v,
  284. cu_seqlens_q,
  285. cu_seqlens_k,
  286. max_seqlen_q,
  287. max_seqlen_k,
  288. softmax_scale,
  289. causal=causal,
  290. q_scale=q_scale, k_scale=k_scale, v_scale=v_scale,
  291. window_size=window_size,
  292. softcap=softcap,
  293. )
  294. ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k)
  295. ctx.max_seqlen_q = max_seqlen_q
  296. ctx.max_seqlen_k = max_seqlen_k
  297. ctx.softmax_scale = softmax_scale
  298. ctx.causal = causal
  299. ctx.window_size = window_size
  300. ctx.softcap = softcap
  301. ctx.deterministic = deterministic
  302. return out, softmax_lse
  303. @staticmethod
  304. def backward(ctx, dout, *args):
  305. q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
  306. dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
  307. _flash_attn_varlen_backward(
  308. dout,
  309. q,
  310. k,
  311. v,
  312. out,
  313. softmax_lse,
  314. cu_seqlens_q,
  315. cu_seqlens_k,
  316. ctx.max_seqlen_q,
  317. ctx.max_seqlen_k,
  318. dq,
  319. dk,
  320. dv,
  321. ctx.softmax_scale,
  322. ctx.causal,
  323. ctx.window_size,
  324. ctx.softcap,
  325. ctx.deterministic,
  326. )
  327. dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
  328. dk = dk[..., : dout.shape[-1]]
  329. dv = dv[..., : dout.shape[-1]]
  330. return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
  331. def flash_attn_qkvpacked_func(
  332. qkv,
  333. softmax_scale=None,
  334. causal=False,
  335. q_scale=None, k_scale=None, v_scale=None,
  336. window_size=(-1, -1),
  337. softcap=0.0,
  338. deterministic=False,
  339. num_heads_q=None,
  340. ):
  341. """dropout_p should be set to 0.0 during evaluation
  342. If Q, K, V are already stacked into 1 tensor, this function will be faster than
  343. calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
  344. of the gradients of Q, K, V.
  345. For multi-query and grouped-query attention (MQA/GQA), please see
  346. flash_attn_kvpacked_func and flash_attn_func.
  347. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  348. will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
  349. Arguments:
  350. qkv: (batch_size, seqlen, 3, nheads, headdim)
  351. dropout_p: float. Dropout probability.
  352. softmax_scale: float. The scaling of QK^T before applying softmax.
  353. Default to 1 / sqrt(headdim).
  354. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  355. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  356. softcap: float. Anything > 0 activates softcapping attention.
  357. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
  358. the attention score of query i and key j.
  359. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  360. which is slightly slower and uses more memory. The forward pass is always deterministic.
  361. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  362. testing only. The returned probabilities are not guaranteed to be correct
  363. (they might not have the right scaling).
  364. Return:
  365. out: (batch_size, seqlen, nheads, headdim).
  366. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  367. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  368. normalization factor).
  369. S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
  370. The output of softmax (possibly with different scaling). It also encodes the dropout
  371. pattern (negative means that location was dropped, nonnegative means it was kept).
  372. """
  373. return FlashAttnQKVPackedFunc.apply(
  374. qkv,
  375. softmax_scale,
  376. causal,
  377. q_scale, k_scale, v_scale,
  378. window_size,
  379. softcap,
  380. deterministic,
  381. num_heads_q,
  382. )
  383. def flash_attn_func(
  384. q,
  385. k,
  386. v,
  387. softmax_scale=None,
  388. causal=False,
  389. q_scale=None, k_scale=None, v_scale=None,
  390. window_size=(-1, -1),
  391. softcap=0.0,
  392. deterministic=False
  393. ):
  394. """dropout_p should be set to 0.0 during evaluation
  395. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
  396. than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
  397. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
  398. 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
  399. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
  400. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
  401. 1 1 1 1 0
  402. 1 1 1 1 1
  403. If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
  404. 0 0
  405. 0 0
  406. 0 0
  407. 1 0
  408. 1 1
  409. If the row of the mask is all zero, the output will be zero.
  410. If window_size != (-1, -1), implements sliding window local attention. Query at position i
  411. will only attend to keys between
  412. [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
  413. Arguments:
  414. q: (batch_size, seqlen, nheads, headdim)
  415. k: (batch_size, seqlen, nheads_k, headdim)
  416. v: (batch_size, seqlen, nheads_k, headdim)
  417. dropout_p: float. Dropout probability.
  418. softmax_scale: float. The scaling of QK^T before applying softmax.
  419. Default to 1 / sqrt(headdim).
  420. causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
  421. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
  422. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
  423. (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
  424. is added to the attention score of query i and key j.
  425. deterministic: bool. Whether to use the deterministic implementation of the backward pass,
  426. which is slightly slower and uses more memory. The forward pass is always deterministic.
  427. return_attn_probs: bool. Whether to return the attention probabilities. This option is for
  428. testing only. The returned probabilities are not guaranteed to be correct
  429. (they might not have the right scaling).
  430. Return:
  431. out: (batch_size, seqlen, nheads, headdim).
  432. softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
  433. logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
  434. normalization factor).
  435. """
  436. return FlashAttnFunc.apply(
  437. q,
  438. k,
  439. v,
  440. softmax_scale,
  441. causal,
  442. q_scale, k_scale, v_scale,
  443. window_size,
  444. softcap,
  445. deterministic,
  446. )
  447. def flash_attn_varlen_func(
  448. q,
  449. k,
  450. v,
  451. cu_seqlens_q,
  452. cu_seqlens_k,
  453. max_seqlen_q,
  454. max_seqlen_k,
  455. softmax_scale=None,
  456. causal=False,
  457. q_scale=None, k_scale=None, v_scale=None,
  458. window_size=(-1, -1),
  459. softcap=0.0,
  460. deterministic=False
  461. ):
  462. return FlashAttnVarlenFunc.apply(
  463. q,
  464. k,
  465. v,
  466. cu_seqlens_q,
  467. cu_seqlens_k,
  468. max_seqlen_q,
  469. max_seqlen_k,
  470. softmax_scale,
  471. causal,
  472. q_scale, k_scale, v_scale,
  473. window_size,
  474. softcap,
  475. deterministic,
  476. )