# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

if triton.__version__ >= "2.1.0":
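
    # Layout of the paged KV cache consumed by the kernels below (see the
    # stride comments in context_attention_fwd):
    #   K_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
    #   V_cache: [num_blocks, num_kv_heads, head_size, block_size]
    # Each kernel runs one program per (sequence, head, query tile) and
    # attends in two passes: first over the cached prefix gathered through
    # the block table B_Loc, then over the new in-flight tokens in K/V,
    # using an online softmax so the full score matrix is never materialized.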
    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
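        # One program handles one (sequence, head, BLOCK_M query tile):
        #   program_id(0) -> sequence in the batch
        #   program_id(1) -> query head
        #   program_id(2) -> tile of that sequence's new (non-cached) tokens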
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return
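
    # Variant of _fwd_kernel that uses the FlashAttention-2 style update:
    # the accumulator is rescaled only by exp(m_i - m_i_new) and the final
    # division by the softmax denominator l_i is deferred (left commented out
    # before the store). Note that context_attention_fwd below dispatches only
    # _fwd_kernel and _fwd_kernel_alibi.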
    @triton.jit
    def _fwd_kernel_flash_attn_v2(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # acc /= l_i[:, None]
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return
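
    # Same streaming-softmax kernel with an ALiBi positional bias: each head
    # loads its slope from Alibi_slopes, every score is offset by
    # slope * (key_pos - query_pos), and the accumulator is divided by the
    # running denominator l_i before the final store.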
    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        # cur_batch_seq_len: the length of the prompt
        # cur_batch_ctx_len: the length of the cached prefix
        # cur_batch_in_all_start_index: the start index along dim 0 of Q
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi for the new (non-cached) tokens
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # compute q[BLOCK_M, BLOCK_DMODEL] against k[prefix_len:, BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return
    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None):
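        """Prefix-prefill attention over new tokens plus a paged KV cache.

        Shapes (inferred from the strides and asserts below; head_dim must be
        one of 16, 32, 64, or 128):
            q, o:    [num_new_tokens, num_heads, head_dim]
            k, v:    [num_new_tokens, num_kv_heads, head_dim]
            k_cache: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
            v_cache: [num_blocks, num_kv_heads, head_dim, block_size]
            b_loc:   [batch, max_blocks_per_seq] block table for the prefix
            b_start_loc / b_seq_len / b_ctx_len: per-sequence start offset
                into q, total sequence length, and cached-prefix length.
        max_input_len is the maximum number of new tokens across the batch
        and sets the launch grid. Dispatches _fwd_kernel_alibi when
        alibi_slopes is provided, otherwise _fwd_kernel; the result is
        written in place into o.
        """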
        cap = torch.cuda.get_device_capability()
        BLOCK = 128 if cap[0] >= 8 else 64
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
        num_queries_per_kv = q.shape[1] // k.shape[1]

        # grid: (batch, head, number of BLOCK_M-sized query tiles)
        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        num_warps = 8
        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
                b_start_loc,
                b_seq_len,
                b_ctx_len,
                alibi_slopes,
                v_cache.shape[3],  # block_size
                8,  # x
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(
                    4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                v_cache.stride(
                    3),  # [num_blocks, num_kv_heads, head_size, block_size]
                num_queries_per_kv=num_queries_per_kv,
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            b_start_loc,
            b_seq_len,
            b_ctx_len,
            v_cache.shape[3],  # block_size
            8,  # x
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(
                4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(
                3),  # [num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
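

# A minimal smoke-test sketch (not part of the adapted LightLLM code). It
# assumes a CUDA device with Triton >= 2.1.0, fp16 tensors, and the cache
# layouts documented above. The sizes below (heads, head_dim, block_size,
# sequence lengths) are illustrative choices, not requirements of the kernels
# beyond head_dim in {16, 32, 64, 128} and a k_cache "x" dimension of 8.
if __name__ == "__main__":
    torch.manual_seed(0)
    device, dtype = "cuda", torch.float16
    batch, num_heads, num_kv_heads, head_dim = 2, 8, 8, 64
    block_size, x = 16, 8
    ctx_len, new_len = 32, 48  # cached prefix tokens / new tokens per sequence
    total_new = batch * new_len

    # New (uncached) tokens and the output buffer.
    q = torch.randn(total_new, num_heads, head_dim, device=device, dtype=dtype)
    k = torch.randn(total_new, num_kv_heads, head_dim, device=device,
                    dtype=dtype)
    v = torch.randn_like(k)
    o = torch.empty_like(q)

    # Paged KV cache holding the prefix tokens.
    blocks_per_seq = ctx_len // block_size
    num_blocks = batch * blocks_per_seq
    k_cache = torch.randn(num_blocks, num_kv_heads, head_dim // x, block_size,
                          x, device=device, dtype=dtype)
    v_cache = torch.randn(num_blocks, num_kv_heads, head_dim, block_size,
                          device=device, dtype=dtype)

    # Block table: sequence i owns blocks [i * blocks_per_seq, ...).
    b_loc = torch.arange(num_blocks, device=device,
                         dtype=torch.int32).view(batch, blocks_per_seq)
    b_start_loc = torch.arange(0, total_new, new_len, device=device,
                               dtype=torch.int32)
    b_seq_len = torch.full((batch, ), ctx_len + new_len, device=device,
                           dtype=torch.int32)
    b_ctx_len = torch.full((batch, ), ctx_len, device=device,
                           dtype=torch.int32)

    context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc,
                          b_seq_len, b_ctx_len, max_input_len=new_len)
    print("output:", tuple(o.shape))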