# fwd_prefill.py

import torch
import triton
import triton.language as tl

from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE


@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    ms = tl.arange(0, m)
    ns = tl.arange(0, n)
    return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    rng_keep = rng_output > dropout_p
    return rng_keep
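
# Note: the keep-mask above is a pure function of (philox_seed, philox_offset) plus the block
# coordinates, which is what allows the same dropout pattern to be regenerated later (e.g. by a
# backward kernel that receives the same seed/offset) without materializing a mask tensor.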


# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit
def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
    if offset_first is not None and offset_second is not None:
        mask = (offset_first[:, None] < boundary_first) & \
               (offset_second[None, :] < boundary_second)
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_first is not None:
        mask = offset_first[:, None] < boundary_first
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_second is not None:
        mask = offset_second[None, :] < boundary_second
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    else:
        tensor = tl.load(ptrs)
    return tensor


@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
    # When seqlen_k and seqlen_q differ, we want the diagonal to stick to the bottom right of the attention matrix.
    # For a causal mask we want something like this, where 1 is kept and 0 is masked:
    # seqlen_q = 2 and seqlen_k = 5
    #   1 1 1 1 0
    #   1 1 1 1 1
    # seqlen_q = 5 and seqlen_k = 2
    #   0 0
    #   0 0
    #   0 0
    #   1 0
    #   1 1
    # For alibi, the diagonal is 0 (no penalty for attending to that spot) and the penalty grows with distance
    # from the diagonal.
    # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
    # 1. offs_m[:, None]                       = [[0],
    #                                             [1], ...]
    # 2. offs_m[:, None] + seqlen_k            = [[5],
    #                                             [6], ...]
    # 3. offs_m[:, None] + seqlen_k - seqlen_q = [[3],
    #                                             [4], ...]
    # 4. offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    #        = [[3],  -  [[0, 1, 2, 3, 4]]  =  [[ 3,  2,  1,  0, -1],
    #           [4]]                            [ 4,  3,  2,  1,  0]]
    # 5. -1 * alibi_slope * tl.abs(relative_pos_block)
    #        = [[-3, -2, -1,  0, -1],
    #           [-4, -3, -2, -1,  0]]
    relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
    if transpose:
        return alibi_block.T
    else:
        return alibi_block
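

# Illustration only (not used by the kernel): a minimal host-side PyTorch sketch of the same
# bottom-right-aligned alibi block, handy for sanity-checking the Triton version above. The name
# `_alibi_block_reference` is hypothetical and exists purely for this example.
def _alibi_block_reference(alibi_slope: float, seqlen_q: int, seqlen_k: int,
                           offs_m: torch.Tensor, offs_n: torch.Tensor) -> torch.Tensor:
    # relative position of each (m, n) pair, shifted so the diagonal hugs the bottom right
    relative_pos = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    # zero penalty on the diagonal, growing linearly with distance from it
    return -alibi_slope * relative_pos.abs()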


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m,
                    actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs,
                    block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs,
                    scores_scaled_shifted_ptrs,
                    IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                    OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
                    ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
                    ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
                    RETURN_SCORES: tl.constexpr):
    if USE_EXP2:
        RCP_LN2: tl.constexpr = 1.4426950408889634
    # loop over k, v and update the accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        if MASK_STEPS:
            k_offs_n = start_n + tl.arange(0, BLOCK_N)
        else:
            k_offs_n = None
        k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
        k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
        if PRE_LOAD_V:
            # We can use the same offsets as k, just with dims transposed.
            v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # Only the last K block can overrun seqlen_k when it is not a multiple of BLOCK_N,
        # so only that iteration needs the padding check.
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:
            # If this is the last block / iteration, we want to mask if the sequence
            # length is not a multiple of the block size.
            # A solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
            # The last step might get wasted, but that is okay. Check that this masking
            # works for that case.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
        # -- compute qk ----
        qk += tl.dot(q, k)
        qk_scaled = qk * SM_SCALE
        if RETURN_SCORES:
            score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
            tl.store(score_ptrs, qk_scaled, mask=score_mask)
        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf"))
        if bias_ptrs is not None:
            bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
            bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
            qk_scaled += bias
        if alibi_slope is not None:
            # Compute the global position of each token within the sequence
            global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            global_n_positions = start_n + tl.arange(0, BLOCK_N)
            alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
                                              global_n_positions)
            qk_scaled += alibi_block
        # get max scores so far
        m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1))
        # subtract the running max
        q_shifted = qk_scaled - m_ij[:, None]
        if RETURN_SCORES:
            # NOTE: the returned score is not the same as the reference because earlier blocks would
            # need to be re-adjusted as new maxes are found per block. We do not do that here.
            scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
            tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask)
        # compute softmax probabilities for this block
        if USE_EXP2:
            p = tl.math.exp2(q_shifted * RCP_LN2)
        else:
            p = tl.math.exp(q_shifted)
        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N
            keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)
            if RETURN_SCORES:
                # NOTE: the returned score is not the same as the reference because earlier blocks would
                # need to be re-adjusted as new maxes are found per block. We do not do that here.
                exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
                tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask)
            p = tl.where(keep, p, 0.0)
        elif RETURN_SCORES:
            # NOTE: the returned score is not the same as the reference because earlier blocks would
            # need to be re-adjusted as new maxes are found per block. We do not do that here.
            exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
            tl.store(exp_scores_ptrs, p, mask=exp_score_mask)
        # -- update output accumulator --
        # alpha is the adjustment factor for acc and l_i as we loop and find new maxes:
        # it rescales previous contributions by exp(old_max - new_max).
        m_diff = m_i - m_ij
        if USE_EXP2:
            alpha = tl.math.exp2(m_diff * RCP_LN2)
        else:
            alpha = tl.math.exp(m_diff)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
        # -- update m_i and l_i --
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        acc += tl.dot(p.to(v.type.element_ty), v)
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
        if bias_ptrs is not None:
            bias_ptrs += BLOCK_N * stride_bn
        if RETURN_SCORES:
            score_ptrs += BLOCK_N
            scores_scaled_shifted_ptrs += BLOCK_N
            exp_scores_ptrs += BLOCK_N
    return acc, l_i, m_i
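

# Illustration only (not used by the kernel): a minimal PyTorch sketch of the streaming
# ("online") softmax update performed above, showing how acc, l_i and m_i evolve per K/V block.
# The helper name `_online_softmax_step_reference` and its arguments are hypothetical and exist
# only for this example; dropout, bias, alibi and masking are intentionally omitted.
def _online_softmax_step_reference(acc, l_i, m_i, qk_scaled, v_block):
    m_ij = torch.maximum(m_i, qk_scaled.max(dim=1).values)  # new running max per row
    p = torch.exp(qk_scaled - m_ij[:, None])                 # probabilities w.r.t. the new max
    alpha = torch.exp(m_i - m_ij)                            # rescale factor for old contributions
    acc = acc * alpha[:, None] + p @ v_block                 # rescale old acc, add this block
    l_i = l_i * alpha + p.sum(dim=1)                         # rescale old normalizer, add new mass
    return acc, l_i, m_ij                                    # final output is acc / l_i[:, None]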


def get_cdna_autotune_configs():
    return [
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=4),
    ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']


def get_rdna_autotune_configs():
    return [
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
    ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']


def get_autotune_configs():
    if AUTOTUNE:
        if is_rdna():
            return get_rdna_autotune_configs()
        elif is_cdna():
            return get_cdna_autotune_configs()
        else:
            raise ValueError("Unknown Device Type")
    else:
        return [
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
                num_stages=1,
                num_warps=4,
            ),
        ], [
            "IS_CAUSAL",
            "dropout_p",
            "MAX_SEQLENS_Q",
            "MAX_SEQLENS_K",
            "ACTUAL_BLOCK_DMODEL",
            "VARLEN",
            "HQ",
            "HK",
        ]


autotune_configs, autotune_keys = get_autotune_configs()


@triton.autotune(
    configs=autotune_configs,
    key=autotune_keys,
    use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
             stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
             stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
             stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
             dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes,
             HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
             MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
             BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
             ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        # print("cu_seqlens_q_start:", cu_seqlens_q_start)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m, so for those we return early.
        if start_m * BLOCK_M > seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K
    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and whether this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular, which means
        # the causal mask boundary is bottom-right aligned and ends at either
        # the top edge (seqlen_q < seqlen_k) or the left edge (seqlen_q > seqlen_k).
        # This captures the decrease in n_blocks if we have a rectangular attn matrix.
        n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts block_max for the current WG, only if IS_CAUSAL.
        # Otherwise we want to always iterate through all n_blocks.
        n_blocks = min(n_blocks, n_blocks_seqlen)
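    # Worked example (illustrative numbers, not from this file): with BLOCK_M = BLOCK_N = 64,
    # seqlen_q = 128, seqlen_k = 512 and start_m = 0, the causal adjustment gives
    # n_blocks_seqlen = cdiv(1 * 64 + 512 - 128, 64) = cdiv(448, 64) = 7, so this WG iterates
    # over only 7 of the cdiv(512, 64) = 8 K/V blocks; the last block lies entirely above the
    # bottom-right-aligned causal boundary for these 64 query rows.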
    # If we have no blocks after adjusting for seqlen deltas, this WG is part of
    # the blocks that are all 0. We exit early.
    if n_blocks <= 0:
        o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
        o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
        o_ptrs_mask = offs_m[:, None] < seqlen_q
        # We still need to write 0s to the result
        tl.store(o_ptrs, acc, mask=o_ptrs_mask)
        # The tensor allocated for L is based on MAX_SEQLENS_Q, as that is
        # statically known.
        l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
        l_ptrs = l_offset + offs_m * stride_lse_m
        l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
        # mask_m_offsets = start_m + tl.arange(0, BLOCK_M)
        # lse_mask = mask_m_offsets < causal_start_idx
        # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
        l_ptrs_mask = offs_m < MAX_SEQLENS_Q
        tl.store(l_ptrs, l, mask=l_ptrs_mask)
        # TODO: Should dropout and return encoded softmax be handled here too?
        return
    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    if GROUP_SIZE != 1:
        off_h_k = off_h_q // GROUP_SIZE
    else:
        off_h_k = off_h_q
    n_extra_tokens = 0
    # print("n_extra_tokens:", n_extra_tokens)
    # print("seqlen_k:", seqlen_k)
    # print("BLOCK_N:", BLOCK_N)
    # return
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
    # Compute pointers for all the tensors used in this kernel.
    q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
    q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
    k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
    v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
    v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
    if USE_BIAS:
        # Note: this might get large enough to overflow on some configs
        bias_offset = off_h_q * stride_bh
        bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn
    else:
        bias_ptrs = None
    if USE_ALIBI:
        a_offset = off_z * stride_az + off_h_q * stride_ah
        alibi_slope = tl.load(alibi_slopes + a_offset)
    else:
        alibi_slope = None
    if RETURN_SCORES:
        scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
        score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
        scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
        scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
        exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
        exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
    else:
        score_ptrs = None
        scores_scaled_shifted_ptrs = None
        exp_scores_ptrs = None
    if ENABLE_DROPOUT:
        off_hz = off_z * HQ + off_h_q
        batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # initialize the running max (m_i) and running sum (l_i) of the online softmax
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # Q is loaded once at the beginning and shared by all N blocks.
    q_ptrs_mask = offs_m[:, None] < seqlen_q
    if PADDED_HEAD:
        q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
    q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # If IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
    # In this case we might exceed n_blocks, so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
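    # Worked example (illustrative numbers, not from this file): with BLOCK_M = 128, BLOCK_N = 64,
    # IS_CAUSAL = True and seqlen_q = seqlen_k = 192 (so n_extra_tokens = 0 but
    # seqlen_q % BLOCK_M != 0), the WG with start_m = 1 gets masked_blocks = 128 // 64 + 1 = 3 and
    # n_blocks = min(cdiv(192, 64), cdiv(2 * 128 + 0, 64)) = min(3, 4) = 3, so n_full_blocks = 0
    # and every K/V block of that WG runs through the masked (MASK_STEPS / causal) path.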
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its actual
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
                                        start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
                                        exp_scores_ptrs,
                                        # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
                                        block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
                                        # IS_CAUSAL, ....
                                        False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
                                        # _, MASK_STEPS, ...
                                        PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
                                        ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
        block_min = block_max
        block_max = n_blocks * BLOCK_N
    tl.debug_barrier()
    # Remaining blocks, if any, are masked and need the MASK_STEPS / causal handling.
    if masked_blocks > 0:
        if IS_CAUSAL:
            offs_n_causal = offs_n + (seqlen_q - seqlen_k)
        else:
            offs_n_causal = 0
        k_ptrs += n_full_blocks * BLOCK_N * stride_kn
        v_ptrs += n_full_blocks * BLOCK_N * stride_vk
        if USE_BIAS:
            bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
        if RETURN_SCORES:
            score_ptrs += n_full_blocks * BLOCK_N
            scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N
            exp_scores_ptrs += n_full_blocks * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
                                        start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
                                        exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
                                        n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
                                        IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
                                        # _, MASK_STEPS, ...
                                        PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
                                        ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
    # epilogue
    # Invert l_i first so the compiler applies its reciprocal (Newton-Raphson) refinement
    # to the small l_i vector rather than to acc, which is much larger.
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs, which come from computing
    # softmax over a row of all -infs (-inf - (-inf) = NaN). We check for that here
    # and store 0s where there are NaNs, as these rows should have been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE (log-sum-exp), the log of the softmax normalization constant
    l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
    l_ptrs = l_offset + offs_m * stride_lse_m
    if USE_EXP2:
        RCP_LN2: tl.constexpr = 1.4426950408889634
        LN2: tl.constexpr = 0.6931471824645996
        # compute log-sum-exp in base-2 units
        mi_base2 = m_i * RCP_LN2
        softmax_lse = mi_base2 + tl.math.log2(l_i)
        # convert back to natural units
        softmax_lse *= LN2
    else:
        softmax_lse = m_i + tl.math.log(l_i)
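    # Note on the base-2 path above: m_i * log2(e) + log2(l_i) = log2(exp(m_i) * l_i), and
    # multiplying by ln(2) converts back to natural units, so both branches compute
    # m_i + log(l_i) = log(sum(exp(scores))); exp2/log2 simply map more directly to the
    # hardware instructions.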
    if IS_CAUSAL:
        # zero out the NaNs caused by -infs when doing causal masking
        lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
        softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
    # If seqlen_q is not a multiple of BLOCK_M, we need to mask out the last few rows.
    # This is only true for the last M block. For others, overflow_size will be negative.
    overflow_size = end_m_idx - seqlen_q
    if overflow_size > 0:
        boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32)
        l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
        tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask)  # the log of the normalization constant
    else:
        tl.store(l_ptrs, softmax_lse)  # the log of the normalization constant
    # write back O
    o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
    o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
    o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
    if overflow_size > 0:
        o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
    if PADDED_HEAD:
        o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
    tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)


def attention_prefill_forward_triton_impl(
        q,
        k,
        v,
        o,
        sm_scale,
        alibi_slopes,
        causal,
        bias,
        dropout_p,
        layout,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        return_scores,
        use_exp2):

    if DEBUG:
        print()
        print("attention_prefill_forward_triton_impl")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("o:", o, o.shape)
        print("sm_scale:", sm_scale)
        print("alibi_slopes:", alibi_slopes)
        print("causal:", causal)
        print("bias:", bias)
        print("dropout_p:", dropout_p)
        print("layout:", layout)
        print("cu_seqlens_q:", cu_seqlens_q)
        print("cu_seqlens_k:", cu_seqlens_k)
        print("max_seqlens_q:", max_seqlens_q)
        print("max_seqlens_k:", max_seqlens_k)
        print("return_scores:", return_scores)
        print("use_exp2:", use_exp2)
    # check if varlen
    is_varlen = layout == "thd"

    # NOTE: a large bias tensor leads to overflow during pointer arithmetic
    if bias is not None:
        assert bias.numel() < 2**31

    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k)
    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)

    # Get the closest power of 2 greater than or equal to head_size.
    padded_d_model = 1 << (head_size - 1).bit_length()
    # The smallest head_dim supported is 16. If head_size is smaller, the tile in the
    # kernel is padded - there is no padding in memory for any dims.
    padded_d_model = max(padded_d_model, 16)
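    # e.g. head_size = 40 -> padded_d_model = 64; head_size = 128 -> 128; head_size = 8 -> 16
    # (the kernel tiles BLOCK_DMODEL at the padded size and masks loads/stores beyond ACTUAL_BLOCK_DMODEL).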
    grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch)

    if return_scores:
        scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
                             dtype=torch.float32)
        scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
                                            dtype=torch.float32)
        scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3))
    else:
        scores = None
        scores_scaled_shifted = None
        scores_strides = (0, 0, 0, 0)

    # exp_scores is used to validate dropout behavior against the PyTorch SDPA math backend reference.
    # We zero it out to give a consistent starting point and then populate it with the softmax output,
    # with the sign bit set according to the dropout mask. This lets the mask be fed into the reference
    # implementation for testing only; the returned tensor holds no useful output aside from debugging.
    if return_scores:
        exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
                                 dtype=torch.float32)
    else:
        exp_scores = None
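    # Hypothetical downstream use (sketch, not part of this file): because kept entries are stored
    # as +p and dropped entries as -p, a test could recover the dropout keep-mask with something
    # like `keep_mask = exp_scores > 0` and apply it to a reference softmax before comparing outputs.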
    # softmax_lse stores the LSE: the log of the normalization constant, i.e. the sum of
    # exponentiated (unnormalized) scores.
    if is_varlen:
        softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
        stride_lse_m, stride_lse_h = softmax_lse.stride()
        stride_lse_z = 0
    else:
        softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
        stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()

    # Seed the RNG so we get reproducible results for testing.
    philox_seed = 0x1BF52
    philox_offset = 0x1D4B42

    if bias is not None:
        bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))
    else:
        bias_strides = (0, 0, 0, 0)

    if alibi_slopes is not None:
        alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1))
    else:
        alibi_strides = (0, 0)

    attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
                   *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m,
                   cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed,
                   philox_offset_base=philox_offset, scores=scores, scores_scaled_shifted=scores_scaled_shifted,
                   exp_scores=exp_scores, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k,
                   ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k,
                   IS_CAUSAL=causal, VARLEN=is_varlen, BLOCK_DMODEL=padded_d_model,
                   USE_BIAS=bias is not None, USE_ALIBI=alibi_slopes is not None,
                   ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)

    return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted
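

# Illustration only: a minimal, hedged sketch of how this entry point might be called for a
# non-varlen batch. It assumes a "bshd" layout (q/k/v/o as [batch, seqlen, heads, head_dim]) is
# among the layouts understood by get_shape_from_layout / get_strides_from_layout, that the
# non-varlen path tolerates cu_seqlens_q/cu_seqlens_k being None, and that an AMD GPU is visible
# to PyTorch; adjust shapes, the layout string and the device for your setup.
def _example_prefill_forward(device="cuda", dtype=torch.float16):
    batch, seqlen, nheads, head_dim = 2, 128, 4, 64
    q = torch.randn(batch, seqlen, nheads, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o = torch.empty_like(q)
    out, lse, exp_scores, *_ = attention_prefill_forward_triton_impl(
        q, k, v, o,
        sm_scale=head_dim ** -0.5,   # standard 1/sqrt(d) scaling
        alibi_slopes=None,
        causal=True,
        bias=None,
        dropout_p=0.0,
        layout="bshd",
        cu_seqlens_q=None, cu_seqlens_k=None,   # only used by the varlen ("thd") path
        max_seqlens_q=seqlen, max_seqlens_k=seqlen,
        return_scores=False,
        use_exp2=True,
    )
    return out, lse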