blocksparse_attention_kernel.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. import torch
  2. import triton
  3. import triton.language as tl
  4. def blocksparse_flash_attn_varlen_fwd(
  5. q,
  6. k,
  7. v, # (#tokens, n_heads, head_size)
  8. cu_seqlens_k,
  9. cu_seqlens_q,
  10. sm_scale,
  11. sparse_layout,
  12. *,
  13. block_size=64,
  14. q_block_size=None,
  15. max_seqlen=None):
  16. # split q to blocks
  17. assert isinstance(sparse_layout, (list, tuple))
  18. _, n_heads, head_size = q.shape
  19. batch_size = cu_seqlens_k.size(0) - 1
  20. q_block_size = q_block_size or block_size
  21. assert q.dim() == k.dim() == v.dim() == 3
  22. assert q.size(1) % k.size(1) == 0
  23. assert q.size(2) == k.size(2)
  24. # TODO: allow k, v to have different head_size
  25. assert k.shape == v.shape
  26. assert cu_seqlens_k.dim() == 1
  27. q_k_ratio = q.size(1) // k.size(1)
  28. if cu_seqlens_q is None:
  29. if q.size(0) == batch_size: # decoding only
  30. cu_seqlens_q = torch.arange(
  31. 0,
  32. batch_size + 1,
  33. dtype=cu_seqlens_k.dtype,
  34. device=cu_seqlens_k.device,
  35. )
  36. elif q.size(0) == k.size(0):
  37. cu_seqlens_q = cu_seqlens_k
  38. else:
  39. raise ValueError("cu_seqlens_q must be specified\
  40. if it mix of prefilling and decoding.")
  41. else:
  42. assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
  43. # switch to use cpu to avoid too many kernel launches when iterated over
  44. q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
  45. k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
  46. assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
  47. "length of q should either be 1 (decoding) or same as k (prefilling).")
  48. if max_seqlen:
  49. assert k_lens.max() <= max_seqlen
  50. n_blocks = (q_lens + q_block_size - 1) // q_block_size
  51. q_batch_ids = torch.tensor(
  52. [i for i, n in enumerate(n_blocks) for _ in range(n)],
  53. dtype=cu_seqlens_q.dtype,
  54. device=cu_seqlens_q.device,
  55. )
  56. q_start_sids = torch.tensor(
  57. [i * q_block_size for n in n_blocks for i in range(n)],
  58. dtype=cu_seqlens_q.dtype,
  59. device=cu_seqlens_q.device,
  60. )
  61. out = q.new_empty(q.shape)
  62. cu_seqlens_q = cu_seqlens_q.contiguous()
  63. cu_seqlens_k = cu_seqlens_k.contiguous()
  64. layout_crow_indices, layout_col_indices = sparse_layout
  65. block_d = triton.next_power_of_2(head_size)
  66. decoding_only = (q_lens == 1).all().item()
  67. grid = (len(q_start_sids), n_heads, 1)
  68. _fwd_kernel_batch_inference[grid](
  69. q,
  70. k,
  71. v,
  72. out,
  73. sm_scale,
  74. cu_seqlens_q[:-1],
  75. cu_seqlens_q[1:],
  76. cu_seqlens_k[:-1],
  77. cu_seqlens_k[1:],
  78. q_batch_ids,
  79. q_start_sids,
  80. 0,
  81. *q.stride(),
  82. 0,
  83. *k.stride(),
  84. 0,
  85. *v.stride(),
  86. 0,
  87. *out.stride(),
  88. layout_crow_indices,
  89. layout_col_indices,
  90. *layout_crow_indices.stride(),
  91. *layout_col_indices.stride(),
  92. q_k_ratio,
  93. HAS_BATCH_DIM=False,
  94. D_HEAD=head_size,
  95. BLOCK_M=q_block_size,
  96. BLOCK_N=block_size,
  97. BLOCK_D=block_d,
  98. BLOCK_M_LOADING=(16 if decoding_only else
  99. q_block_size), # smaller for decoding
  100. EVEN_D=block_d == head_size,
  101. num_warps=1 if decoding_only else 4,
  102. num_stages=3)
  103. return out
  104. @triton.jit
  105. def _fwd_kernel_inner(
  106. acc,
  107. l_i,
  108. m_i,
  109. q,
  110. Q,
  111. k_block_col_idx,
  112. layout_col_ptr,
  113. layout_col_stride_h,
  114. layout_col_stride_m,
  115. k_ptrs,
  116. v_ptrs,
  117. off_h,
  118. offs_m,
  119. offs_n,
  120. offs_d,
  121. stride_kt,
  122. stride_vt,
  123. sm_scale,
  124. k_seqlen,
  125. past_len,
  126. LAST_K_BLOCK: tl.constexpr,
  127. BLOCK_M_LOADING: tl.constexpr,
  128. BLOCK_N: tl.constexpr,
  129. D_HEAD: tl.constexpr,
  130. EVEN_D: tl.constexpr,
  131. M_LT_N: tl.constexpr,
  132. ):
  133. k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
  134. k_block_col_idx * layout_col_stride_m).to(tl.int32)
  135. start_n = k_block_id * BLOCK_N
  136. if LAST_K_BLOCK:
  137. if EVEN_D:
  138. k = tl.load(
  139. k_ptrs + start_n * stride_kt,
  140. mask=offs_n[None, :] + start_n < k_seqlen,
  141. )
  142. else:
  143. k = tl.load(
  144. k_ptrs + start_n * stride_kt,
  145. mask=(offs_n[None, :] + start_n < k_seqlen) &
  146. (offs_d[:, None] < D_HEAD),
  147. )
  148. else:
  149. if EVEN_D:
  150. k = tl.load(k_ptrs + start_n * stride_kt)
  151. else:
  152. k = tl.load(k_ptrs + start_n * stride_kt,
  153. mask=offs_d[:, None] < D_HEAD)
  154. qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
  155. qk += tl.dot(q, k)
  156. qk *= sm_scale
  157. # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
  158. if LAST_K_BLOCK | M_LT_N:
  159. qk += tl.where(
  160. offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
  161. 0,
  162. float("-inf"),
  163. )
  164. # flash-attn2
  165. m_ij = tl.maximum(m_i, tl.max(qk, 1))
  166. p = tl.math.exp2(qk - m_ij[:, None])
  167. l_ij = tl.sum(p, 1)
  168. alpha = tl.math.exp2(m_i - m_ij)
  169. acc = acc * alpha[:, None]
  170. # update m_i
  171. m_i = m_ij
  172. l_i = l_i * alpha + l_ij
  173. p = p.to(Q.dtype.element_ty)
  174. # update acc
  175. if LAST_K_BLOCK:
  176. if EVEN_D:
  177. v = tl.load(
  178. v_ptrs + start_n * stride_vt,
  179. mask=offs_n[:, None] + start_n < k_seqlen,
  180. )
  181. else:
  182. v = tl.load(
  183. v_ptrs + start_n * stride_vt,
  184. mask=(offs_n[:, None] + start_n < k_seqlen) &
  185. (offs_d[None, :] < D_HEAD),
  186. )
  187. else:
  188. if EVEN_D:
  189. v = tl.load(v_ptrs + start_n * stride_vt)
  190. else:
  191. v = tl.load(v_ptrs + start_n * stride_vt,
  192. mask=offs_d[None, :] < D_HEAD)
  193. acc += tl.dot(p, v)
  194. return acc, l_i, m_i
  195. @triton.heuristics({
  196. "M_LT_N":
  197. lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
  198. })
  199. @triton.jit
  200. def _fwd_kernel_batch_inference(
  201. Q,
  202. K,
  203. V,
  204. Out,
  205. sm_scale,
  206. q_batch_starts,
  207. q_batch_ends,
  208. k_batch_starts,
  209. k_batch_ends,
  210. q_batch_ids,
  211. q_start_sids,
  212. stride_qb,
  213. stride_qt,
  214. stride_qh,
  215. stride_qd,
  216. stride_kb,
  217. stride_kt,
  218. stride_kh,
  219. stride_kd,
  220. stride_vb,
  221. stride_vt,
  222. stride_vh,
  223. stride_vd,
  224. stride_ob,
  225. stride_ot,
  226. stride_oh,
  227. stride_od,
  228. layout_crow_ptr,
  229. layout_col_ptr,
  230. layout_crow_stride_h,
  231. layout_crow_stride_m,
  232. layout_col_stride_h,
  233. layout_col_stride_m,
  234. q_k_ratio,
  235. HAS_BATCH_DIM: tl.constexpr,
  236. D_HEAD: tl.constexpr,
  237. BLOCK_M: tl.constexpr,
  238. BLOCK_N: tl.constexpr,
  239. BLOCK_D: tl.constexpr,
  240. BLOCK_M_LOADING: tl.constexpr,
  241. EVEN_D: tl.constexpr,
  242. M_LT_N: tl.constexpr,
  243. ):
  244. """
  245. NOTATION:
  246. pid: position id
  247. sid: storage id
  248. sbid: storage block id
  249. pbid: position block id
  250. offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
  251. TODO:
  252. Optimize grouped-attn
  253. """
  254. off_zm = tl.program_id(0)
  255. off_h = tl.program_id(1)
  256. off_h_for_kv = off_h // q_k_ratio
  257. if HAS_BATCH_DIM:
  258. off_z = tl.program_id(2)
  259. Q += off_z * stride_qb
  260. K += off_z * stride_kb
  261. V += off_z * stride_vb
  262. Out += off_z * stride_ob
  263. start_m = off_zm
  264. q_start_sid = start_m * BLOCK_M # always 0 for decoding
  265. else:
  266. off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
  267. q_start_sid = tl.load(q_start_sids + off_zm)
  268. start_m = q_start_sid // BLOCK_M # q_sbid
  269. offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
  270. offs_n = tl.arange(0, BLOCK_N)
  271. offs_d = tl.arange(0, BLOCK_D)
  272. q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
  273. q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
  274. k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
  275. k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
  276. past_len = k_seqlen - q_seqlen
  277. Q += q_cu_start * stride_qt + off_h * stride_qh
  278. K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
  279. V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
  280. Out += q_cu_start * stride_ot + off_h * stride_oh
  281. q_pbid = (past_len + q_start_sid) // BLOCK_M
  282. if EVEN_D:
  283. q = tl.load(
  284. Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
  285. mask=offs_m[:, None] < q_seqlen,
  286. )
  287. else:
  288. q = tl.load(
  289. Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
  290. mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
  291. other=0,
  292. )
  293. sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
  294. q_pbid * layout_crow_stride_m)
  295. # TODO: load at once, with any Triton version
  296. # that supports `tl.split`, e.g., Triton 3.0
  297. k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
  298. k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
  299. m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
  300. l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
  301. acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
  302. k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
  303. v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
  304. sm_scale *= (
  305. 1.44269504 # 1/log2 as we use base2 for exponential and logarithm
  306. )
  307. for k_block_col_idx in range(k_block_start, k_block_end - 1):
  308. acc, l_i, m_i = _fwd_kernel_inner(
  309. acc,
  310. l_i,
  311. m_i,
  312. q,
  313. Q,
  314. k_block_col_idx,
  315. layout_col_ptr,
  316. layout_col_stride_h,
  317. layout_col_stride_m,
  318. k_ptrs,
  319. v_ptrs,
  320. off_h,
  321. offs_m,
  322. offs_n,
  323. offs_d,
  324. stride_kt,
  325. stride_vt,
  326. sm_scale,
  327. k_seqlen,
  328. past_len,
  329. False,
  330. BLOCK_M_LOADING,
  331. BLOCK_N,
  332. D_HEAD,
  333. EVEN_D,
  334. M_LT_N,
  335. )
  336. acc, l_i, m_i = _fwd_kernel_inner(
  337. acc,
  338. l_i,
  339. m_i,
  340. q,
  341. Q,
  342. k_block_end - 1,
  343. layout_col_ptr,
  344. layout_col_stride_h,
  345. layout_col_stride_m,
  346. k_ptrs,
  347. v_ptrs,
  348. off_h,
  349. offs_m,
  350. offs_n,
  351. offs_d,
  352. stride_kt,
  353. stride_vt,
  354. sm_scale,
  355. k_seqlen,
  356. past_len,
  357. True,
  358. BLOCK_M_LOADING,
  359. BLOCK_N,
  360. D_HEAD,
  361. EVEN_D,
  362. M_LT_N,
  363. )
  364. # flash-attn 2
  365. m_i += tl.math.log2(l_i)
  366. acc = acc / l_i[:, None]
  367. # write output
  368. if EVEN_D:
  369. tl.store(
  370. Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
  371. acc,
  372. mask=offs_m[:, None] < q_seqlen,
  373. )
  374. else:
  375. tl.store(
  376. Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
  377. acc,
  378. mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
  379. )