# tree_attn.py
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
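#
# _fwd_kernel computes attention for `tree_width` speculative (draft) query
# tokens per sequence against a paged KV cache. Each query may attend to every
# prompt token, but among the draft tokens only to those at the same offset
# modulo tree_width, i.e. its own column of the token tree; see the tree-mask
# comment inside the kernel.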
import torch
import triton
import triton.language as tl

if triton.__version__ >= "2.1.0":

    @triton.jit
    def _fwd_kernel(
        Q,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Ctxlen,
        prompt_lens,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        tree_width: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_in_all_start_index = cur_batch * tree_width
        cur_batch_prompt_len = tl.load(prompt_lens + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(Q + off_q, mask=offs_m[:, None] < tree_width, other=0.0)

        # initialize the running max (m_i), normalizer (l_i) and accumulator
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
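
        # One tile of the online-softmax (FlashAttention-style) recurrence per
        # BLOCK_N chunk of cached tokens: track the running max m_i and running
        # normalizer l_i, rescaling past contributions whenever the max grows.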
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)

            cur_step = start_n + offs_n[None, :]  # [1, BLOCK_N]
            is_prompt = cur_step < cur_batch_prompt_len  # [1, BLOCK_N]
            # Draft tokens are laid out tree_width per decoding step, so query
            # row m may attend a draft token only if that token sits at the
            # same offset modulo tree_width, i.e. in the same column of the
            # tree. Prompt positions (where the subtraction can go negative)
            # are admitted by is_prompt regardless.
            tree_mask = (
                cur_step - cur_batch_prompt_len - offs_m[:, None]
            ) % tree_width == 0  # [1, BLOCK_N] - [BLOCK_M, 1] -> [BLOCK_M, BLOCK_N]
            # `or`/`and` are not elementwise on Triton tensors; use `|`/`&`
            tree_mask = is_prompt | tree_mask
            mask = tree_mask & (cur_step < cur_batch_ctx_len)
            qk = tl.where(mask, qk, -3.4028234663852886e+38)  # float32 lowest
            qk *= sm_scale
            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # p is renormalized by l_i_new on every iteration, so acc is already
        # the final softmax-weighted sum; no division by l_i is needed here.
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs, acc, mask=offs_m[:, None] < tree_width)
        return
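
    # A plain-PyTorch reference of the mask applied by _fwd_kernel, handy for
    # unit-testing the kernel against dense attention. This helper is a sketch
    # added for illustration (the name and the True-means-may-attend
    # convention are assumptions, not part of the adapted source).
    def _tree_mask_reference(ctx_len: int, prompt_len: int,
                             tree_width: int) -> torch.Tensor:
        # rows: the tree_width query tokens; columns: the ctx_len cached tokens
        n = torch.arange(ctx_len)[None, :]
        m = torch.arange(tree_width)[:, None]
        is_prompt = n < prompt_len
        same_column = (n - prompt_len - m) % tree_width == 0
        return is_prompt | same_column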

    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,  # keys of the new (uncached) tokens; read by the second loop below
        V,  # values of the new (uncached) tokens
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        # cur_batch_seq_len: the total number of tokens in this sequence
        # cur_batch_ctx_len: the number of cached prefix tokens
        # cur_batch_in_all_start_index: the row in dim 0 at which this
        #     sequence's new tokens start
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize the running max (m_i), normalizer (l_i) and accumulator
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        # first pass: attention over the cached prefix
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale
            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N
            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # p is kept unnormalized here, so only acc is rescaled by alpha
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # second pass: attention over the new tokens (K/V, not the cache)
        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v
        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # calc q[BLOCK_M, BLOCK_DMODEL] mul k[prefix_len:, BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            # causal mask within the new tokens
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))
            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N
            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # p is kept unnormalized here, so only acc is rescaled by alpha
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # normalize by the final softmax denominator
        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return
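
    # NOTE: _fwd_kernel_alibi keeps the two-pass structure of the source
    # context_attention_fwd kernel: the first loop attends over the paged
    # prefix cache, the second over the freshly computed K/V of the new
    # tokens. It is retained for reference and is not launched by
    # tree_attention_fwd below.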

    @torch.inference_mode()
    def tree_attention_fwd(q,
                           o,
                           k_cache,
                           v_cache,
                           block_table,
                           context_len,
                           prompt_len,
                           tree_width,
                           alibi_slopes=None):  # accepted but unused here
        cap = torch.cuda.get_device_capability()
        BLOCK_N = 128 if cap[0] >= 8 else 64
        # round the query rows per sequence up to a multiple of 16 so that
        # tl.dot gets well-shaped tiles
        BLOCK_M = triton.cdiv(tree_width, 16) * 16

        # shape constraints
        Lq = q.shape[-1]
        Lk = k_cache.shape[-1] * k_cache.shape[-3]  # head_size = x * (head_size // x)
        assert Lq == Lk
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = context_len.shape[0], q.shape[1]
        num_queries_per_kv = q.shape[1] // k_cache.shape[1]

        grid = (batch, head, triton.cdiv(tree_width, BLOCK_M))  # (batch, head, query blocks)
        num_warps = 8
        _fwd_kernel[grid](
            q,
            k_cache,
            v_cache,
            block_table,
            sm_scale,
            context_len,
            prompt_len,
            v_cache.shape[3],  # block_size
            k_cache.shape[4],  # x
            o,
            block_table.stride(0),
            block_table.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            # k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(4),
            # v_cache: [num_blocks, num_kv_heads, head_size, block_size]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(3),
            num_queries_per_kv=num_queries_per_kv,
            tree_width=tree_width,
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK_N,
            num_warps=num_warps,
            num_stages=1,
        )
        return
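
# A minimal smoke-test sketch of how tree_attention_fwd might be invoked
# (requires a CUDA device and triton >= 2.1.0). The shapes are assumptions
# inferred from the stride arguments above:
#   q/o:     [batch * tree_width, num_heads, head_size]
#   k_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
#   v_cache: [num_blocks, num_kv_heads, head_size, block_size]
# All sizes below are illustrative, not part of the original file.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, num_heads, num_kv_heads, head_size = 2, 8, 8, 64
    num_blocks, block_size, x = 64, 16, 8
    tree_width = 4
    prompt = 40                    # prompt tokens per sequence
    ctx = prompt + 2 * tree_width  # cached tokens: prompt + two tree steps

    dev, dtype = "cuda", torch.float16
    q = torch.randn(batch * tree_width, num_heads, head_size,
                    device=dev, dtype=dtype)
    o = torch.empty_like(q)
    k_cache = torch.randn(num_blocks, num_kv_heads, head_size // x,
                          block_size, x, device=dev, dtype=dtype)
    v_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          device=dev, dtype=dtype)
    # give each sequence its own run of physical blocks
    blocks_per_seq = triton.cdiv(ctx, block_size)
    block_table = torch.arange(batch * blocks_per_seq, device=dev,
                               dtype=torch.int32).view(batch, blocks_per_seq)
    context_len = torch.full((batch,), ctx, device=dev, dtype=torch.int32)
    prompt_len = torch.full((batch,), prompt, device=dev, dtype=torch.int32)

    tree_attention_fwd(q, o, k_cache, v_cache, block_table, context_len,
                       prompt_len, tree_width)
    print(o.shape, torch.isnan(o).any().item())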