flash_attn_triton.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160
  1. """
  2. *Experimental* implementation of FlashAttention in Triton.
  3. Tested with triton==2.0.0.dev20221202.
  4. Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
  5. other than 64:
  6. https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
  7. We'll update this implementation with the new Triton backend once this is fixed.
  8. We use the FlashAttention implementation from Phil Tillet as a starting point.
  9. https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
  10. Changes:
  11. - Implement both causal and non-causal attention.
  12. - Implement both self-attention and cross-attention.
  13. - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
  14. - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
  15. - Support attention bias.
  16. - Speed up the forward pass a bit, and only store the LSE instead of m and l.
  17. - Make the backward for d=128 much faster by reducing register spilling.
  18. - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
  19. small batch size * nheads.
  20. Caution:
  21. - This is an *experimental* implementation. The forward pass should be quite robust but
  22. I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
  23. - This implementation has only been tested on A100.
  24. - If you plan to use headdim other than 64 and 128, you should test for race conditions
  25. (due to the Triton compiler), as done in tests/test_flash_attn.py
  26. "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
  27. for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
  28. that there are none left for other head dimensions.
  29. Differences between this Triton version and the CUDA version:
  30. - Triton version doesn't support dropout.
  31. - Triton forward is generally faster than CUDA forward, while Triton backward is
  32. generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
  33. than CUDA forward + backward.
  34. - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
  35. - Triton version supports attention bias, while CUDA version doesn't.
  36. """
  37. import math
  38. import torch
  39. import triton
  40. import triton.language as tl
  41. # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
  42. # @triton.autotune(
  43. # configs=[
  44. # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
  45. # # This config has a race condition when EVEN_M == False, disabling it for now.
  46. # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
  47. # ],
  48. # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
  49. # )
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Forward pass. Each program instance handles one block of BLOCK_M query rows
    # (grid axis 0) of one flattened (batch, head) pair (grid axis 1).
    # It streams over the keys/values in blocks of BLOCK_N, keeping a running
    # log-sum-exp (lse_i) and an fp32 output accumulator (acc_o), then writes the
    # per-row LSE to Lse and the block of results to Out.
    #
    # Q/K/V/Bias/Out are addressed with explicit batch, head and row strides; the
    # offset along the last dimension is added without a stride, so that dimension
    # is presumably contiguous (stride 1) -- TODO confirm against the launcher.
    # BIAS_TYPE is compared against "none", "vector" and "matrix" below.
    # EVEN_M / EVEN_N / EVEN_HEADDIM come from the heuristics above and select the
    # unmasked load/store variants when no padding is involved.
    # CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are not read in the kernel body.
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    # Split the flattened index into batch and head indices.
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    # Bias pointers: a "vector" bias is indexed by key position only, a "matrix"
    # bias by (query row, key position).
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias
            + off_b * stride_bb
            + off_h * stride_bh
            + (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    # initialize pointer to the scratchpad, and the running softmax statistics
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over k, v and update accumulator
    # With causal masking, key positions past the last query row of this block are
    # never visited.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        # Masking is done by adding -inf to the disallowed scores.
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
                    ).to(tl.float32)
                # Broadcast the per-key bias across all query rows.
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs + start_n,
                        mask=(offs_m[:, None] < seqlen_q)
                        & ((start_n + offs_n)[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            # The reference value for this step is the max of the block's row maxima
            # and the running log-sum-exp (lse_i), not the previous m_i.
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # scale acc_o: move the accumulator from the old reference m_i to the new m_ij
        acc_o_scale = tl.exp(m_i - m_ij)
        # # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # Cast the probabilities to v's dtype before the matmul.
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)
    # Final rescale: move the accumulator from the last reference m_i to lse_i.
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back the log-sum-exp (the only softmax statistic that is stored)
    # NOTE(review): this store is unmasked; it relies on Lse having seqlen_q_rounded
    # entries per (batch, head) covering every row of the last block -- confirm in launcher.
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    # Store the result, masking rows past seqlen_q and columns past headdim as needed.
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
  273. @triton.jit
  274. def _bwd_preprocess_do_o_dot(
  275. Out,
  276. DO,
  277. Delta,
  278. stride_ob,
  279. stride_oh,
  280. stride_om,
  281. stride_dob,
  282. stride_doh,
  283. stride_dom,
  284. nheads,
  285. seqlen_q,
  286. seqlen_q_rounded,
  287. headdim,
  288. BLOCK_M: tl.constexpr,
  289. BLOCK_HEADDIM: tl.constexpr,
  290. ):
  291. start_m = tl.program_id(0)
  292. off_hb = tl.program_id(1)
  293. off_b = off_hb // nheads
  294. off_h = off_hb % nheads
  295. # initialize offsets
  296. offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
  297. offs_d = tl.arange(0, BLOCK_HEADDIM)
  298. # load
  299. o = tl.load(
  300. Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
  301. mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
  302. other=0.0,
  303. ).to(tl.float32)
  304. do = tl.load(
  305. DO
  306. + off_b * stride_dob
  307. + off_h * stride_doh
  308. + offs_m[:, None] * stride_dom
  309. + offs_d[None, :],
  310. mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
  311. other=0.0,
  312. ).to(tl.float32)
  313. delta = tl.sum(o * do, axis=1)
  314. # write-back
  315. tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs,
    dv_ptrs,
    dk,
    dv,
    offs_n,
    offs_d,
    seqlen_k,
    headdim,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
):
    # Store the dk and dv tiles to dk_ptrs / dv_ptrs.
    # offs_n indexes the tile rows (compared against seqlen_k) and offs_d the tile
    # columns (compared against headdim); the constexpr EVEN_* flags pick which of
    # those bounds masks are applied. The unmasked-row path is only taken when both
    # EVEN_N and EVEN_M hold, as a workaround for the issue described below.
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            # Only the head dimension is padded: mask columns past headdim.
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            # Mask rows past seqlen_k.
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            # Mask both rows past seqlen_k and columns past headdim.
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_bm,
    stride_dom,
    stride_dqm,
    stride_dkn,
    stride_dvn,
    seqlen_q,
    seqlen_k,
    headdim,
    ATOMIC_ADD: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Backward pass for a single block of BLOCK_N key/value positions (block index start_n).
    # K and V for the block are loaded once; the kernel then loops over blocks of BLOCK_M
    # query rows, recomputes the attention probabilities p from Q, K and the saved LSE,
    # accumulates dv and dk in fp32, and adds this block's contribution to DQ.
    # ATOMIC_ADD selects how DQ is updated: load/add/store when False, tl.atomic_add when True.
    # Only row strides are applied to the tensor arguments here, so any batch/head offset
    # must already be folded into the pointers by the caller.
    # D holds the per-row value subtracted from dp when forming ds (see below).
    #
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    # With causal masking, the row loop starts at the row block containing this
    # column block's first position instead of at row 0.
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may have zero step, and pipelining with the bias matrix could screw it up.
    # So we just exit early.
    # In that case dk and dv are still all zeros, and zeros are what gets stored.
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(
            dk_ptrs,
            dv_ptrs,
            dk,
            dv,
            offs_n,
            offs_d,
            seqlen_k,
            headdim,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
        )
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
        if EVEN_M & EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(
                    q_ptrs,
                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        # Here masked scores are replaced by -inf via tl.where.
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        if BIAS_TYPE != "none":
            tl.debug_barrier()  # Race condition otherwise
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                # Broadcast the per-key bias across all query rows.
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs).to(tl.float32)
                else:
                    bias = tl.load(
                        b_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
                        other=0.0,
                    ).to(tl.float32)
            # With a bias, the scale is applied here rather than inside tl.exp below.
            qk = qk * softmax_scale + bias
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        # Also wrong for headdim=64.
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        # NOTE(review): LSE (and D further down) are loaded without a row mask;
        # presumably both buffers are padded to a multiple of BLOCK_M -- confirm in launcher.
        lse_i = tl.load(LSE + offs_m_curr)
        if BIAS_TYPE == "none":
            p = tl.exp(qk * softmax_scale - lse_i[:, None])
        else:
            p = tl.exp(qk - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        else:
            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
            do = tl.load(
                do_ptrs,
                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0,
            )
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        # else:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
        #     else:
        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
        #                                    & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
        if not (EVEN_M & EVEN_HEADDIM):
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # There's a race condition for headdim=48
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not (
            EVEN_M & EVEN_HEADDIM
        ):  # Otherwise there's a race condition when BIAS_TYPE='matrix'
            tl.debug_barrier()
        if not ATOMIC_ADD:
            # Read-modify-write of the DQ tile: load, add this block's contribution, store.
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                dq += tl.dot(ds, k)
                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(
                        dq_ptrs,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=offs_m_curr[:, None] < seqlen_q,
                        eviction_policy="evict_last",
                    )
                else:
                    dq = tl.load(
                        dq_ptrs,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0,
                        eviction_policy="evict_last",
                    )
                    dq += tl.dot(ds, k)
                    tl.store(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        eviction_policy="evict_last",
                    )
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
                tl.atomic_add(dq_ptrs, dq)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(
                        dq_ptrs,
                        dq,
                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    )
        # increment pointers to the next block of BLOCK_M rows
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
        if BIAS_TYPE == "matrix":
            b_ptrs += BLOCK_M * stride_bm
    # write-back the accumulated dk and dv for this column block
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    _bwd_store_dk_dv(
        dk_ptrs,
        dv_ptrs,
        dk,
        dv,
        offs_n,
        offs_d,
        seqlen_k,
        headdim,
        EVEN_M=EVEN_M,
        EVEN_N=EVEN_N,
        EVEN_HEADDIM=EVEN_HEADDIM,
    )
  613. def init_to_zero(name):
  614. return lambda nargs: nargs[name].zero_()
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
            num_warps=8,
            num_stages=1,
            pre_hook=init_to_zero("DQ"),
        ),
        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
        # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
    ],
    key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    Bias,
    DO,
    DQ,
    DK,
    DV,
    LSE,
    D,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_dob,
    stride_doh,
    stride_dom,
    stride_dqb,
    stride_dqh,
    stride_dqm,
    stride_dkb,
    stride_dkh,
    stride_dkn,
    stride_dvb,
    stride_dvh,
    stride_dvn,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Backward-pass entry kernel. Grid axis 1 indexes the flattened (batch, head) pair.
    # After offsetting every tensor pointer to that pair, it either
    #   - loops over all key/value column blocks inside this program
    #     (SEQUENCE_PARALLEL=False; DQ is updated with plain load/add/store), or
    #   - handles only the column block given by grid axis 0
    #     (SEQUENCE_PARALLEL=True; DQ is updated with atomic adds).
    # Both autotune configs above zero DQ in a pre_hook, because the per-column-block
    # routine adds into DQ rather than overwriting it.
    # BLOCK_M, BLOCK_N and SEQUENCE_PARALLEL come from the selected autotune config;
    # the EVEN_* flags come from the heuristics above.
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != "none":
        Bias += off_b * stride_bb + off_h * stride_bh
    # pointer to row-wise quantities in value-like data
    # (D and LSE are laid out with seqlen_q_rounded entries per (batch, head))
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        # Serial over column blocks: this program visits every key/value block.
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q,
                K,
                V,
                Bias,
                DO,
                DQ,
                DK,
                DV,
                LSE,
                D,
                softmax_scale,
                stride_qm,
                stride_kn,
                stride_vn,
                stride_bm,
                stride_dom,
                stride_dqm,
                stride_dkn,
                stride_dvn,
                seqlen_q,
                seqlen_k,
                headdim,
                ATOMIC_ADD=False,
                BIAS_TYPE=BIAS_TYPE,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M,
                EVEN_N=EVEN_N,
                EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
    else:
        # Parallel over column blocks: grid axis 0 picks the block, and DQ is updated
        # with atomic adds (ATOMIC_ADD=True).
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q,
            K,
            V,
            Bias,
            DO,
            DQ,
            DK,
            DV,
            LSE,
            D,
            softmax_scale,
            stride_qm,
            stride_kn,
            stride_vn,
            stride_bm,
            stride_dom,
            stride_dqm,
            stride_dkn,
            stride_dvn,
            seqlen_q,
            seqlen_k,
            headdim,
            ATOMIC_ADD=True,
            BIAS_TYPE=BIAS_TYPE,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M,
            EVEN_N=EVEN_N,
            EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
  788. def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
  789. # shape constraints
  790. batch, seqlen_q, nheads, d = q.shape
  791. _, seqlen_k, _, _ = k.shape
  792. assert k.shape == (batch, seqlen_k, nheads, d)
  793. assert v.shape == (batch, seqlen_k, nheads, d)
  794. assert d <= 128, "FlashAttention only support head dimensions up to 128"
  795. assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
  796. assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
  797. assert q.is_cuda and k.is_cuda and v.is_cuda
  798. softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
  799. has_bias = bias is not None
  800. bias_type = "none"
  801. if has_bias:
  802. assert bias.dtype in [q.dtype, torch.float]
  803. assert bias.is_cuda
  804. assert bias.dim() == 4
  805. if bias.stride(-1) != 1:
  806. bias = bias.contiguous()
  807. if bias.shape[2:] == (1, seqlen_k):
  808. bias_type = "vector"
  809. elif bias.shape[2:] == (seqlen_q, seqlen_k):
  810. bias_type = "matrix"
  811. else:
  812. raise RuntimeError(
  813. "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
  814. )
  815. bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
  816. bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
  817. seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
  818. lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
  819. tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
  820. o = torch.empty_like(q)
  821. BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
  822. BLOCK = 128
  823. num_warps = 4 if d <= 64 else 8
  824. grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
  825. _fwd_kernel[grid](
  826. q,
  827. k,
  828. v,
  829. bias,
  830. o,
  831. lse,
  832. tmp,
  833. softmax_scale,
  834. q.stride(0),
  835. q.stride(2),
  836. q.stride(1),
  837. k.stride(0),
  838. k.stride(2),
  839. k.stride(1),
  840. v.stride(0),
  841. v.stride(2),
  842. v.stride(1),
  843. *bias_strides,
  844. o.stride(0),
  845. o.stride(2),
  846. o.stride(1),
  847. nheads,
  848. seqlen_q,
  849. seqlen_k,
  850. seqlen_q_rounded,
  851. d,
  852. seqlen_q // 32,
  853. seqlen_k // 32, # key for triton cache (limit number of compilations)
  854. # Can't use kwargs here because triton autotune expects key to be args, not kwargs
  855. # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
  856. bias_type,
  857. causal,
  858. BLOCK_HEADDIM,
  859. BLOCK_M=BLOCK,
  860. BLOCK_N=BLOCK,
  861. num_warps=num_warps,
  862. num_stages=1,
  863. )
  864. return o, lse, softmax_scale # softmax_scale could have been updated
  865. def _flash_attn_backward(
  866. do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
  867. ):
  868. # Make sure that the last dimension is contiguous
  869. if do.stride(-1) != 1:
  870. do = do.contiguous()
  871. batch, seqlen_q, nheads, d = q.shape
  872. _, seqlen_k, _, _ = k.shape
  873. # assert d in {16, 32, 64, 128}
  874. assert d <= 128
  875. seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
  876. assert lse.shape == (batch, nheads, seqlen_q_rounded)
  877. assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
  878. assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
  879. softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
  880. # dq_accum = torch.zeros_like(q, dtype=torch.float32)
  881. dq_accum = torch.empty_like(q, dtype=torch.float32)
  882. delta = torch.empty_like(lse)
  883. # delta = torch.zeros_like(lse)
  884. BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
  885. grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
  886. _bwd_preprocess_do_o_dot[grid](
  887. o,
  888. do,
  889. delta,
  890. o.stride(0),
  891. o.stride(2),
  892. o.stride(1),
  893. do.stride(0),
  894. do.stride(2),
  895. do.stride(1),
  896. nheads,
  897. seqlen_q,
  898. seqlen_q_rounded,
  899. d,
  900. BLOCK_M=128,
  901. BLOCK_HEADDIM=BLOCK_HEADDIM,
  902. )
  903. has_bias = bias is not None
  904. bias_type = "none"
  905. if has_bias:
  906. assert bias.dtype in [q.dtype, torch.float]
  907. assert bias.is_cuda
  908. assert bias.dim() == 4
  909. assert bias.stride(-1) == 1
  910. if bias.shape[2:] == (1, seqlen_k):
  911. bias_type = "vector"
  912. elif bias.shape[2:] == (seqlen_q, seqlen_k):
  913. bias_type = "matrix"
  914. else:
  915. raise RuntimeError(
  916. "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
  917. )
  918. bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
  919. bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
  920. # BLOCK_M = 128
  921. # BLOCK_N = 64
  922. # num_warps = 4
  923. grid = lambda META: (
  924. triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
  925. batch * nheads,
  926. )
  927. _bwd_kernel[grid](
  928. q,
  929. k,
  930. v,
  931. bias,
  932. do,
  933. dq_accum,
  934. dk,
  935. dv,
  936. lse,
  937. delta,
  938. softmax_scale,
  939. q.stride(0),
  940. q.stride(2),
  941. q.stride(1),
  942. k.stride(0),
  943. k.stride(2),
  944. k.stride(1),
  945. v.stride(0),
  946. v.stride(2),
  947. v.stride(1),
  948. *bias_strides,
  949. do.stride(0),
  950. do.stride(2),
  951. do.stride(1),
  952. dq_accum.stride(0),
  953. dq_accum.stride(2),
  954. dq_accum.stride(1),
  955. dk.stride(0),
  956. dk.stride(2),
  957. dk.stride(1),
  958. dv.stride(0),
  959. dv.stride(2),
  960. dv.stride(1),
  961. nheads,
  962. seqlen_q,
  963. seqlen_k,
  964. seqlen_q_rounded,
  965. d,
  966. seqlen_q // 32,
  967. seqlen_k // 32, # key for triton cache (limit number of compilations)
  968. # Can't use kwargs here because triton autotune expects key to be args, not kwargs
  969. # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
  970. bias_type,
  971. causal,
  972. BLOCK_HEADDIM,
  973. # SEQUENCE_PARALLEL=False,
  974. # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
  975. # num_warps=num_warps,
  976. # num_stages=1,
  977. )
  978. dq.copy_(dq_accum)
  979. class FlashAttnQKVPackedFunc(torch.autograd.Function):
  980. @staticmethod
  981. def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
  982. """
  983. qkv: (batch, seqlen, 3, nheads, headdim)
  984. bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
  985. For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
  986. ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
  987. """
  988. # Make sure that the last dimension is contiguous
  989. if qkv.stride(-1) != 1:
  990. qkv = qkv.contiguous()
  991. o, lse, ctx.softmax_scale = _flash_attn_forward(
  992. qkv[:, :, 0],
  993. qkv[:, :, 1],
  994. qkv[:, :, 2],
  995. bias=bias,
  996. causal=causal,
  997. softmax_scale=softmax_scale,
  998. )
  999. ctx.save_for_backward(qkv, o, lse, bias)
  1000. ctx.causal = causal
  1001. return o
  1002. @staticmethod
  1003. def backward(ctx, do):
  1004. qkv, o, lse, bias = ctx.saved_tensors
  1005. assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
  1006. # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
  1007. # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
  1008. with torch.inference_mode():
  1009. dqkv = torch.empty_like(qkv)
  1010. _flash_attn_backward(
  1011. do,
  1012. qkv[:, :, 0],
  1013. qkv[:, :, 1],
  1014. qkv[:, :, 2],
  1015. o,
  1016. lse,
  1017. dqkv[:, :, 0],
  1018. dqkv[:, :, 1],
  1019. dqkv[:, :, 2],
  1020. bias=bias,
  1021. causal=ctx.causal,
  1022. softmax_scale=ctx.softmax_scale,
  1023. )
  1024. return dqkv, None, None, None
  1025. flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
  1026. class FlashAttnKVPackedFunc(torch.autograd.Function):
  1027. @staticmethod
  1028. def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
  1029. """
  1030. q: (batch, seqlen_q, nheads, headdim)
  1031. kv: (batch, seqlen_k, 2, nheads, headdim)
  1032. bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
  1033. For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
  1034. ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
  1035. """
  1036. # Make sure that the last dimension is contiguous
  1037. q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
  1038. o, lse, ctx.softmax_scale = _flash_attn_forward(
  1039. q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
  1040. )
  1041. ctx.save_for_backward(q, kv, o, lse, bias)
  1042. ctx.causal = causal
  1043. return o
  1044. @staticmethod
  1045. def backward(ctx, do):
  1046. q, kv, o, lse, bias = ctx.saved_tensors
  1047. if len(ctx.needs_input_grad) >= 3:
  1048. assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
  1049. # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
  1050. # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
  1051. with torch.inference_mode():
  1052. dq = torch.empty_like(q)
  1053. dkv = torch.empty_like(kv)
  1054. _flash_attn_backward(
  1055. do,
  1056. q,
  1057. kv[:, :, 0],
  1058. kv[:, :, 1],
  1059. o,
  1060. lse,
  1061. dq,
  1062. dkv[:, :, 0],
  1063. dkv[:, :, 1],
  1064. bias=bias,
  1065. causal=ctx.causal,
  1066. softmax_scale=ctx.softmax_scale,
  1067. )
  1068. return dq, dkv, None, None, None
  1069. flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
  1070. class FlashAttnFunc(torch.autograd.Function):
  1071. @staticmethod
  1072. def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
  1073. """
  1074. q: (batch_size, seqlen_q, nheads, headdim)
  1075. k, v: (batch_size, seqlen_k, nheads, headdim)
  1076. bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
  1077. For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
  1078. ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
  1079. """
  1080. # Make sure that the last dimension is contiguous
  1081. q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
  1082. o, lse, ctx.softmax_scale = _flash_attn_forward(
  1083. q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
  1084. )
  1085. ctx.save_for_backward(q, k, v, o, lse, bias)
  1086. ctx.causal = causal
  1087. return o
  1088. @staticmethod
  1089. def backward(ctx, do):
  1090. q, k, v, o, lse, bias = ctx.saved_tensors
  1091. assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
  1092. # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
  1093. # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
  1094. with torch.inference_mode():
  1095. dq = torch.empty_like(q)
  1096. dk = torch.empty_like(k)
  1097. dv = torch.empty_like(v)
  1098. _flash_attn_backward(
  1099. do,
  1100. q,
  1101. k,
  1102. v,
  1103. o,
  1104. lse,
  1105. dq,
  1106. dk,
  1107. dv,
  1108. bias=bias,
  1109. causal=ctx.causal,
  1110. softmax_scale=ctx.softmax_scale,
  1111. )
  1112. return dq, dk, dv, None, None, None
  1113. flash_attn_func = FlashAttnFunc.apply