flash_attn_triton_og.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
# for benchmarking.
# We fixed a few dtype casts to make it work for bf16.
  4. """
  5. Fused Attention
  6. ===============
  7. This is a Triton implementation of the Flash Attention algorithm
  8. (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
  9. """
  10. import pytest
  11. import torch
  12. import triton
  13. import triton.language as tl
  14. @triton.jit
  15. def _fwd_kernel(
  16. Q,
  17. K,
  18. V,
  19. sm_scale,
  20. TMP,
  21. L,
  22. M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
  23. Out,
  24. stride_qz,
  25. stride_qh,
  26. stride_qm,
  27. stride_qk,
  28. stride_kz,
  29. stride_kh,
  30. stride_kn,
  31. stride_kk,
  32. stride_vz,
  33. stride_vh,
  34. stride_vk,
  35. stride_vn,
  36. stride_oz,
  37. stride_oh,
  38. stride_om,
  39. stride_on,
  40. Z,
  41. H,
  42. N_CTX,
  43. BLOCK_M: tl.constexpr,
  44. BLOCK_DMODEL: tl.constexpr,
  45. BLOCK_N: tl.constexpr,
  46. ):
  47. start_m = tl.program_id(0)
  48. off_hz = tl.program_id(1)
  49. # initialize offsets
  50. offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
  51. offs_n = tl.arange(0, BLOCK_N)
  52. offs_d = tl.arange(0, BLOCK_DMODEL)
  53. off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
  54. off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
  55. off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
  56. # Initialize pointers to Q, K, V
  57. q_ptrs = Q + off_q
  58. k_ptrs = K + off_k
  59. v_ptrs = V + off_v
  60. # initialize pointer to m and l
  61. t_ptrs = TMP + off_hz * N_CTX + offs_m
  62. m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
  63. l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
  64. acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
  65. # load q: it will stay in SRAM throughout
  66. q = tl.load(q_ptrs)
  67. # loop over k, v and update accumulator
  68. for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
  69. start_n = tl.multiple_of(start_n, BLOCK_N)
  70. # -- compute qk ----
  71. k = tl.load(k_ptrs + start_n * stride_kn)
  72. qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
  73. qk += tl.dot(q, k, trans_b=True)
  74. qk *= sm_scale
  75. qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
  76. # -- compute m_ij, p, l_ij
  77. m_ij = tl.max(qk, 1)
  78. p = tl.exp(qk - m_ij[:, None])
  79. l_ij = tl.sum(p, 1)
  80. # -- update m_i and l_i
  81. m_i_new = tl.maximum(m_i, m_ij)
  82. alpha = tl.exp(m_i - m_i_new)
  83. beta = tl.exp(m_ij - m_i_new)
  84. l_i_new = alpha * l_i + beta * l_ij
  85. # -- update output accumulator --
  86. # scale p
  87. p_scale = beta / l_i_new
  88. p = p * p_scale[:, None]
  89. # scale acc
  90. acc_scale = l_i / l_i_new * alpha
  91. tl.store(t_ptrs, acc_scale)
  92. acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
  93. acc = acc * acc_scale[:, None]
  94. # update acc
  95. v = tl.load(v_ptrs + start_n * stride_vk)
  96. p = p.to(v.dtype)
  97. acc += tl.dot(p, v)
  98. # update m_i and l_i
  99. l_i = l_i_new
  100. m_i = m_i_new
  101. # rematerialize offsets to save registers
  102. start_m = tl.program_id(0)
  103. offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
  104. # write back l and m
  105. l_ptrs = L + off_hz * N_CTX + offs_m
  106. m_ptrs = M + off_hz * N_CTX + offs_m
  107. tl.store(l_ptrs, l_i)
  108. tl.store(m_ptrs, m_i)
  109. # initialize pointers to output
  110. offs_n = tl.arange(0, BLOCK_DMODEL)
  111. off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
  112. out_ptrs = Out + off_o
  113. tl.store(out_ptrs, acc)
  114. @triton.jit
  115. def _bwd_preprocess(
  116. Out,
  117. DO,
  118. L,
  119. NewDO,
  120. Delta,
  121. BLOCK_M: tl.constexpr,
  122. D_HEAD: tl.constexpr,
  123. ):
  124. off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
  125. off_n = tl.arange(0, D_HEAD)
  126. # load
  127. o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
  128. do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
  129. denom = tl.load(L + off_m).to(tl.float32)
  130. # compute
  131. do = do / denom[:, None]
  132. delta = tl.sum(o * do, axis=1)
  133. # write-back
  134. tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
  135. tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    DO,
    DQ,
    DK,
    DV,
    L,
    M,
    D,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    Z,
    H,
    N_CTX,
    num_block,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal flash-attention backward pass for one flattened (batch, head) index.

    Program id 0 is ``off_hz`` in ``[0, Z * H)``. For each key/value block
    ``start_n`` (``num_block`` of them), the kernel:

    1. loads that block's K and V tiles once,
    2. walks every query block at or below the diagonal,
    3. recomputes the unnormalized probabilities ``exp(qk * sm_scale - M)``,
    4. accumulates ``dv`` and ``dk`` for the key block (stored after the inner
       loop), and adds into ``DQ`` by load / add / store.

    ``DQ`` must therefore be zero-initialized before launch. ``DO`` is expected
    to be the gradient pre-divided by ``L``, and ``D`` the per-row ``Delta``
    (both produced by ``_bwd_preprocess``). ``Out``, ``L`` and ``Z`` are
    accepted but not read here.

    NOTE(review): the batch/head offsets of every tensor use Q's
    ``stride_qz``/``stride_qh``. V, DO, DQ and DV also use Q's row/col strides;
    K and DK use ``stride_kn``/``stride_kk``. All tensors must therefore share
    Q's memory layout. The query tile spans BLOCK_M rows while ``offs_m`` spans
    BLOCK_N, so BLOCK_M == BLOCK_N is assumed. There is no bounds masking:
    ``num_block * BLOCK_M`` rows are processed, so N_CTX is assumed to be a
    multiple of BLOCK_M.
    """
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head (all via Q's strides; see NOTE above)
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    # outer loop: one key/value block per iteration
    for start_n in range(0, num_block):
        # causal: only query blocks starting at this key block's rows contribute
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # base pointers to the per-row quantities (Delta and row max),
        # indexed as row-major (Z * H, N_CTX) buffers
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv and dk (fp32 accumulators for this key block)
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v are loaded once per key block and reused by the inner loop
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # inner loop: query blocks from the diagonal to the end of the sequence
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load the current q tile (k and v are already loaded)
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, k, trans_b=True)
            # causal mask: query row must not precede the key column
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv += p^T @ do
            do = tl.load(do_ptrs)
            dv += tl.dot(p.to(do.dtype), do, trans_a=True)
            # compute dp = do @ v^T - Delta (Delta subtracted up front)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, v, trans_b=True)
            # compute ds = p * (do @ v^T - Delta) * sm_scale
            ds = p * dp * sm_scale
            # compute dk += ds^T @ q
            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
            # compute dq: accumulate into DQ by load / add / store
            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
            dq += tl.dot(ds.to(k.dtype), k)
            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
            # advance the query-side pointers to the next query block
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back of this key block's dv and dk
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
  238. class _attention(torch.autograd.Function):
  239. @staticmethod
  240. def forward(ctx, q, k, v, sm_scale):
  241. BLOCK = 128
  242. # shape constraints
  243. Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
  244. assert Lq == Lk and Lk == Lv
  245. assert Lk in {16, 32, 64, 128}
  246. o = torch.empty_like(q)
  247. grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
  248. tmp = torch.empty(
  249. (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
  250. )
  251. L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
  252. m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
  253. num_warps = 4 if Lk <= 64 else 8
  254. _fwd_kernel[grid](
  255. q,
  256. k,
  257. v,
  258. sm_scale,
  259. tmp,
  260. L,
  261. m,
  262. o,
  263. q.stride(0),
  264. q.stride(1),
  265. q.stride(2),
  266. q.stride(3),
  267. k.stride(0),
  268. k.stride(1),
  269. k.stride(2),
  270. k.stride(3),
  271. v.stride(0),
  272. v.stride(1),
  273. v.stride(2),
  274. v.stride(3),
  275. o.stride(0),
  276. o.stride(1),
  277. o.stride(2),
  278. o.stride(3),
  279. q.shape[0],
  280. q.shape[1],
  281. q.shape[2],
  282. BLOCK_M=BLOCK,
  283. BLOCK_N=BLOCK,
  284. BLOCK_DMODEL=Lk,
  285. num_warps=num_warps,
  286. num_stages=1,
  287. )
  288. ctx.save_for_backward(q, k, v, o, L, m)
  289. ctx.BLOCK = BLOCK
  290. ctx.grid = grid
  291. ctx.sm_scale = sm_scale
  292. ctx.BLOCK_DMODEL = Lk
  293. return o
  294. @staticmethod
  295. def backward(ctx, do):
  296. q, k, v, o, l, m = ctx.saved_tensors
  297. do = do.contiguous()
  298. dq = torch.zeros_like(q, dtype=torch.float32)
  299. dk = torch.empty_like(k)
  300. dv = torch.empty_like(v)
  301. do_scaled = torch.empty_like(do)
  302. delta = torch.empty_like(l)
  303. _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
  304. o,
  305. do,
  306. l,
  307. do_scaled,
  308. delta,
  309. BLOCK_M=ctx.BLOCK,
  310. D_HEAD=ctx.BLOCK_DMODEL,
  311. )
  312. # NOTE: kernel currently buggy for other values of `num_warps`
  313. num_warps = 8
  314. _bwd_kernel[(ctx.grid[1],)](
  315. q,
  316. k,
  317. v,
  318. ctx.sm_scale,
  319. o,
  320. do_scaled,
  321. dq,
  322. dk,
  323. dv,
  324. l,
  325. m,
  326. delta,
  327. q.stride(0),
  328. q.stride(1),
  329. q.stride(2),
  330. q.stride(3),
  331. k.stride(0),
  332. k.stride(1),
  333. k.stride(2),
  334. k.stride(3),
  335. v.stride(0),
  336. v.stride(1),
  337. v.stride(2),
  338. v.stride(3),
  339. q.shape[0],
  340. q.shape[1],
  341. q.shape[2],
  342. ctx.grid[0],
  343. BLOCK_M=ctx.BLOCK,
  344. BLOCK_N=ctx.BLOCK,
  345. BLOCK_DMODEL=ctx.BLOCK_DMODEL,
  346. num_warps=num_warps,
  347. num_stages=1,
  348. )
  349. return dq.to(q.dtype), dk, dv, None
# Public entry point: attention(q, k, v, sm_scale) -> causal attention output
# (differentiable w.r.t. q, k and v via _attention.backward).
attention = _attention.apply