#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team

Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.

Not currently supported:
1) Non power of two head dims
"""

import torch
import triton
import triton.language as tl

torch_dtype: tl.constexpr = torch.float16

@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)


@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    ms = tl.arange(0, m)
    ns = tl.arange(0, n)
    return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                                  stride).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                             stride)
    rng_keep = rng_output > dropout_p
    return rng_keep

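
# Note: the autograd wrapper at the end of this file launches attn_fwd with
# dropout_p=0.0 and ENABLE_DROPOUT=False, so the dropout helpers above are only
# exercised when the kernel is invoked directly with dropout enabled.
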
@triton.jit
def load_fn(block_ptr, first, second, pad):
    # Apply boundary checks (with padding) only along the dims flagged by
    # `first`/`second`, so fully in-range blocks take the unchecked fast path.
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor

@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
):
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
            # if not is_modulo_mn. last step might get wasted but that is okay.
            # check if this masking works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M],
                                     actual_seqlen_k,
                                     dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        qk += tl.dot(q, k)
        if bias_ptr is not None:
            bias = load_fn(bias_ptr, False, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i
        l_i = l_i * alpha + l_ij
        # update m_i and l_i
        m_i = m_ij
        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i

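
# Reference sketch (not used by the kernel): a minimal PyTorch version of the
# streaming-softmax update that _attn_fwd_inner performs per BLOCK_N tile,
# assuming no causal mask, bias, padding, or dropout. The kernel uses
# exp2(x * log2(e)) instead of exp(x); plain exp is used here for clarity.
# The function name and tile-list arguments are illustrative only.
def _online_softmax_reference(q, k_tiles, v_tiles, sm_scale):
    """Accumulate softmax(q @ k^T * sm_scale) @ v one key/value tile at a time."""
    m_rows = q.shape[0]
    m_i = torch.full((m_rows, ), float("-inf"))          # running row max
    l_i = torch.zeros(m_rows)                            # running row sum
    acc = torch.zeros(m_rows, v_tiles[0].shape[1])       # unnormalized output
    for k, v in zip(k_tiles, v_tiles):
        qk = (q.float() @ k.float().T) * sm_scale
        m_ij = torch.maximum(m_i, qk.max(dim=1).values)  # new running max
        p = torch.exp(qk - m_ij[:, None])                # tile probabilities
        alpha = torch.exp(m_i - m_ij)                    # rescale earlier tiles
        acc = acc * alpha[:, None] + p @ v.float()
        l_i = l_i * alpha + p.sum(dim=1)
        m_i = m_ij
    return acc / l_i[:, None]                            # epilogue: normalize
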
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 64,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": True,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {
                "BLOCK_M": 16,
                "BLOCK_N": 16,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
    ],
    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    L,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    stride_bz,
    stride_bh,
    stride_bm,
    stride_bn,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
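    # One program instance computes a BLOCK_M x BLOCK_DMODEL tile of Out:
    # program_id(0) selects the query block, program_id(1) the query head,
    # and program_id(2) the batch entry (the packed sequence when VARLEN).
    # K and V are streamed in BLOCK_N-sized tiles via _attn_fwd_inner; full
    # (unmasked) tiles are processed first, then the masked/padded tiles.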
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                        off_h_q * stride_oh)
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # We still need to write 0s to the result
            # tl.store(O_block_ptr,
            #          acc.to(Out.type.element_ty), boundary_check=(0,1))
            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            #          + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
                cu_seqlens_q_start * stride_qm)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
                cu_seqlens_k_start * stride_kn)
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
                cu_seqlens_k_start * stride_vk)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = philox_offset_base \
            + (off_z * HQ + off_h_q) \
            * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # initialize pointer to m and l
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, padded_head, "zero")
    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)

    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    tl.debug_barrier()
    # Remaining blocks, if any, are the masked blocks.
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, n_full_blocks))
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
        )
    # epilogue
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                        causal_start_idx,
                                        dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = (mask_m_offsets[:, None] >=
                             out_mask_boundary[None, :])
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #     boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #     # This is a > check because mask being 0 blocks the store.
    #     l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                off_h_q * stride_oh)
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need a boundary check here to make sure the padding from the
    # Q and KV tensors in both dims is not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))

def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    assert q.dim() == k.dim() and q.dim() == v.dim()
    if varlen:
        assert q.dim() == 3
        total_q, nheads_q, head_size = q.shape
        total_k, nheads_k, _ = k.shape
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
    else:
        assert q.dim() == 4
        batch, nheads_q, seqlen_q, head_size = q.shape
        _, nheads_k, seqlen_k, _ = k.shape
        assert max_seqlens > 0
    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert head_size <= 256
    assert o.shape == q.shape
    assert (nheads_q % nheads_k) == 0

class _attention(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
    ):
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)

        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        if True:  # varlen
            total_q, nheads_q, head_size = q.shape
            total_k, nheads_k, _ = k.shape
            batch = len(cu_seqlens_q) - 1
            q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
            k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
            v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
            o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
        else:
            batch, seqlen_q, nheads_q, head_size = q.shape
            _, seqlen_k, nheads_k, _ = k.shape
            q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
            k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
            v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
            o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

        # Get closest power of 2 over or equal to 32.
        unpadded_head_dims = {32, 64, 128, 256}
        if head_size not in unpadded_head_dims:
            padded_d_model = None
            for i in unpadded_head_dims:
                if i > head_size:
                    padded_d_model = i
                    break
            assert padded_d_model is not None
        else:
            padded_d_model = head_size

        grid = lambda META: (
            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
            nheads_q,
            batch,
        )

        encoded_softmax = None

        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42

        if bias is not None:
            bias_strides = (
                bias.stride(0),
                bias.stride(1),
                bias.stride(2),
                bias.stride(3),
            )
        else:
            bias_strides = (0, 0, 0, 0)

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax

triton_attention = _attention.apply
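
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the kernel): calls the varlen entry
# point with two packed sequences. Assumes a GPU that Triton can target is
# available; the shapes and values below are made up for demonstration.
if __name__ == "__main__":
    if torch.cuda.is_available():
        seqlens = [128, 64]                      # per-sequence query/key lengths
        total_tokens = sum(seqlens)
        nheads, head_size = 8, 64
        q = torch.randn(total_tokens, nheads, head_size,
                        dtype=torch.float16, device="cuda")
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # Cumulative sequence lengths for the packed batch of two sequences.
        cu_seqlens = torch.tensor([0, 128, 192], dtype=torch.int32, device="cuda")
        out, _ = triton_attention(
            q, k, v,
            None,                                # o: allocated internally
            cu_seqlens, cu_seqlens,              # same lengths for q and k
            max(seqlens), max(seqlens),          # max_seqlens_q / max_seqlens_k
            True,                                # causal
            head_size**-0.5,                     # sm_scale
            None,                                # bias
        )
        print(out.shape)                         # (total_tokens, nheads, head_size)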