prefix_prefill.py

# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
import triton
import triton.language as tl

if triton.__version__ >= "2.1.0":
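
    # _fwd_kernel: prefill attention over a paged KV cache, without any
    # positional bias. Each program handles one (batch, head, BLOCK_M query
    # rows) tile. It first accumulates attention over the cached prefix
    # tokens, gathering K/V blocks through the block table B_Loc, and then
    # over the freshly computed K/V of the new tokens under a causal mask.
    # The softmax is computed online (running max m_i, running sum l_i); both
    # p and acc are rescaled every iteration, so acc is already normalized
    # when it is stored.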
    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return
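
    # _fwd_kernel_flash_attn_v2: same two-phase loop as _fwd_kernel, but with
    # the FlashAttention-2 style update: p is kept unnormalized and only the
    # accumulator is rescaled by alpha = exp(m_i - m_i_new). Note that the
    # trailing division of acc by l_i is left commented out in this variant.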
  178. @triton.jit
  179. def _fwd_kernel_flash_attn_v2(
  180. Q,
  181. K,
  182. V,
  183. K_cache,
  184. V_cache,
  185. B_Loc,
  186. sm_scale,
  187. B_Start_Loc,
  188. B_Seqlen,
  189. B_Ctxlen,
  190. block_size,
  191. x,
  192. Out,
  193. stride_b_loc_b,
  194. stride_b_loc_s,
  195. stride_qbs,
  196. stride_qh,
  197. stride_qd,
  198. stride_kbs,
  199. stride_kh,
  200. stride_kd,
  201. stride_vbs,
  202. stride_vh,
  203. stride_vd,
  204. stride_obs,
  205. stride_oh,
  206. stride_od,
  207. stride_k_cache_bs,
  208. stride_k_cache_h,
  209. stride_k_cache_d,
  210. stride_k_cache_bl,
  211. stride_k_cache_x,
  212. stride_v_cache_bs,
  213. stride_v_cache_h,
  214. stride_v_cache_d,
  215. stride_v_cache_bl,
  216. BLOCK_M: tl.constexpr,
  217. BLOCK_DMODEL: tl.constexpr,
  218. BLOCK_N: tl.constexpr,
  219. ):
  220. cur_batch = tl.program_id(0)
  221. cur_head = tl.program_id(1)
  222. start_m = tl.program_id(2)
  223. cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
  224. cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
  225. cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
  226. block_start_loc = BLOCK_M * start_m
  227. # initialize offsets
  228. offs_n = tl.arange(0, BLOCK_N)
  229. offs_d = tl.arange(0, BLOCK_DMODEL)
  230. offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
  231. off_q = (
  232. (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
  233. cur_head * stride_qh + offs_d[None, :] * stride_qd)
  234. q = tl.load(
  235. Q + off_q,
  236. mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
  237. other=0.0)
  238. # # initialize pointer to m and l
  239. m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
  240. l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
  241. acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
  242. for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
  243. start_n = tl.multiple_of(start_n, BLOCK_N)
  244. # -- compute qk ----
  245. bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
  246. ((start_n + offs_n) // block_size) * stride_b_loc_s,
  247. mask=(start_n + offs_n) < cur_batch_ctx_len,
  248. other=0)
  249. off_k = (bn[None, :] * stride_k_cache_bs +
  250. cur_head * stride_k_cache_h +
  251. (offs_d[:, None] // x) * stride_k_cache_d +
  252. ((start_n + offs_n[None, :]) % block_size) *
  253. stride_k_cache_bl +
  254. (offs_d[:, None] % x) * stride_k_cache_x)
  255. off_v = (
  256. bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
  257. offs_d[None, :] * stride_v_cache_d +
  258. (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
  259. k = tl.load(K_cache + off_k,
  260. mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
  261. other=0.0)
  262. qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
  263. qk += tl.dot(q, k)
  264. qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
  265. float("-inf"))
  266. qk *= sm_scale
  267. # -- compute m_ij, p, l_ij
  268. m_ij = tl.max(qk, 1)
  269. m_i_new = tl.maximum(m_i, m_ij)
  270. p = tl.math.exp(qk - m_i_new[:, None])
  271. l_ij = tl.sum(p, 1)
  272. # -- update m_i and l_i
  273. alpha = tl.math.exp(m_i - m_i_new)
  274. l_i_new = alpha * l_i + l_ij
  275. # -- update output accumulator --
  276. # scale p
  277. # scale acc
  278. acc_scale = alpha
  279. # acc_scale = l_i / l_i_new * alpha
  280. acc = acc * acc_scale[:, None]
  281. # update acc
  282. v = tl.load(V_cache + off_v,
  283. mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
  284. other=0.0)
  285. p = p.to(v.dtype)
  286. acc += tl.dot(p, v)
  287. # update m_i and l_i
  288. l_i = l_i_new
  289. m_i = m_i_new
  290. off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
  291. offs_d[:, None] * stride_kd)
  292. off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
  293. offs_d[None, :] * stride_vd)
  294. k_ptrs = K + off_k
  295. v_ptrs = V + off_v
  296. block_mask = tl.where(
  297. block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
  298. for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
  299. start_n = tl.multiple_of(start_n, BLOCK_N)
  300. # -- compute qk ----
  301. k = tl.load(k_ptrs +
  302. (cur_batch_in_all_start_index + start_n) * stride_kbs,
  303. mask=(start_n + offs_n[None, :]) <
  304. cur_batch_seq_len - cur_batch_ctx_len,
  305. other=0.0)
  306. qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
  307. qk += tl.dot(q, k)
  308. qk *= sm_scale
  309. qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
  310. float("-inf"))
  311. # -- compute m_ij, p, l_ij
  312. m_ij = tl.max(qk, 1)
  313. m_i_new = tl.maximum(m_i, m_ij)
  314. p = tl.math.exp(qk - m_i_new[:, None])
  315. l_ij = tl.sum(p, 1)
  316. # -- update m_i and l_i
  317. alpha = tl.math.exp(m_i - m_i_new)
  318. l_i_new = alpha * l_i + l_ij
  319. # -- update output accumulator --
  320. # scale p
  321. # scale acc
  322. acc_scale = alpha
  323. # acc_scale = l_i / l_i_new * alpha
  324. acc = acc * acc_scale[:, None]
  325. # update acc
  326. v = tl.load(v_ptrs +
  327. (cur_batch_in_all_start_index + start_n) * stride_vbs,
  328. mask=(start_n + offs_n[:, None]) <
  329. cur_batch_seq_len - cur_batch_ctx_len,
  330. other=0.0)
  331. p = p.to(v.dtype)
  332. acc += tl.dot(p, v)
  333. # update m_i and l_i
  334. l_i = l_i_new
  335. m_i = m_i_new
  336. # acc /= l_i[:, None]
  337. # initialize pointers to output
  338. off_o = (
  339. (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
  340. cur_head * stride_oh + offs_d[None, :] * stride_od)
  341. out_ptrs = Out + off_o
  342. tl.store(out_ptrs,
  343. acc,
  344. mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
  345. return
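
    # _fwd_kernel_alibi: the FlashAttention-2 style loop with an ALiBi
    # positional bias added to qk. The bias is slope * (key_pos - query_pos),
    # using the per-head slope from Alibi_slopes; entries where the bias would
    # be positive or the query position is out of range are masked to -inf.
    # Here acc is divided by l_i once at the end, just before the store.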
    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        # attn_bias[]
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # # init debugger
        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
        # offset_db_k = tl.arange(0, BLOCK_N)
        # calc q[BLOCK_M, BLOCK_DMODEL] mul k[prefix_len:, BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return
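
    # context_attention_fwd: host-side launcher. It picks the tile size from
    # the GPU compute capability, builds a (batch, head, query-block) grid,
    # and dispatches to the ALiBi kernel when slopes are provided, otherwise
    # to the plain kernel.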
    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None):
        cap = torch.cuda.get_device_capability()
        BLOCK = 128 if cap[0] >= 8 else 64

        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head, query blocks
        num_warps = 8

        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
                b_start_loc,
                b_seq_len,
                b_ctx_len,
                alibi_slopes,
                v_cache.shape[3],  # block_size
                8,  # x
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                # k_cache layout: [num_blocks, num_kv_heads, head_size/x, block_size, x]
                k_cache.stride(4),
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                # v_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
                v_cache.stride(3),
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            b_start_loc,
            b_seq_len,
            b_ctx_len,
            v_cache.shape[3],  # block_size
            8,  # x
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            # k_cache layout: [num_blocks, num_kv_heads, head_size/x, block_size, x]
            k_cache.stride(4),
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            # v_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
            v_cache.stride(3),
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
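

# A minimal smoke-test sketch (not part of the adapted LightLLM code). All
# shapes below are assumptions inferred from the stride comments above:
# q/k/v/o hold the new tokens as [num_tokens, num_heads, head_size], the key
# cache is [num_blocks, num_kv_heads, head_size // x, block_size, x] with
# x = 8 (the value hard-coded in context_attention_fwd), and the value cache
# is [num_blocks, num_kv_heads, head_size, block_size].
if __name__ == "__main__" and torch.cuda.is_available() \
        and triton.__version__ >= "2.1.0":
    num_tokens, num_heads, head_size = 16, 4, 64
    num_blocks, block_size, x = 8, 16, 8
    dev, dtype = "cuda", torch.float16
    q = torch.randn(num_tokens, num_heads, head_size, device=dev, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o = torch.empty_like(q)
    k_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size,
                          x, device=dev, dtype=dtype)
    v_cache = torch.randn(num_blocks, num_heads, head_size, block_size,
                          device=dev, dtype=dtype)
    # One sequence: 16 cached prefix tokens (one cache block) + 16 new tokens.
    b_loc = torch.zeros(1, 4, dtype=torch.int32, device=dev)
    b_start_loc = torch.tensor([0], dtype=torch.int32, device=dev)
    b_seq_len = torch.tensor([32], dtype=torch.int32, device=dev)
    b_ctx_len = torch.tensor([16], dtype=torch.int32, device=dev)
    context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc,
                          b_seq_len, b_ctx_len, max_input_len=16)
    print(o.shape)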