prefix_prefill.py

# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl
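
# Expected tensor layouts, as implied by the stride arguments and the comments
# on the kernel launches below (summarized here for convenience):
#   q, o    : [num_tokens, num_heads, head_size]
#   k, v    : [num_tokens, num_kv_heads, head_size]
#   k_cache : [num_blocks, num_kv_heads, head_size // x, block_size, x]  (x == 8 in the launches below)
#   v_cache : [num_blocks, num_kv_heads, head_size, block_size]
#   b_loc   : [batch, max_blocks_per_seq]  (block table into the paged cache)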
if triton.__version__ >= "2.1.0":

    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_query_len),
                    other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

        # first phase: attend to the already-cached prefix tokens (paged KV cache)
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

        # second phase: causal attention over the new (uncached) query tokens
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_query_len))
        return

    @triton.jit
    def _fwd_kernel_flash_attn_v2(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # normalize by the softmax denominator (the running updates above
        # scale acc by alpha only, as in the alibi kernel below)
        acc = acc / l_i[:, None]
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        # attn_bias[]
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # # init debugger
        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
        # offset_db_k = tl.arange(0, BLOCK_N)
        # calc q[BLOCK_M, BLOCK_DMODEL] mul k[prefix_len:, BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)
            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # normalize by the softmax denominator
        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None):

        cap = torch.cuda.get_device_capability()
        BLOCK = 128 if cap[0] >= 8 else 64
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        # round up Lk to a power of 2 - this is required for Triton block size
        Lk_padded = 2**((Lk - 1).bit_length())

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
        num_queries_per_kv = q.shape[1] // k.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head

        num_warps = 8 if Lk <= 64 else 8
        if alibi_slopes is not None:
            assert Lk == Lk_padded
            _fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
                b_start_loc,
                b_seq_len,
                b_ctx_len,
                alibi_slopes,
                v_cache.shape[3],  # block_size
                8,  # x
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
                num_queries_per_kv=num_queries_per_kv,
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            b_start_loc,
            b_seq_len,
            b_ctx_len,
            v_cache.shape[3],  # block_size
            8,  # x
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
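

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original kernels): a minimal, illustrative
# smoke test for context_attention_fwd. All sizes below are assumptions chosen
# for the example; the layouts follow the stride arguments used above
# (q/k/v/o as [num_tokens, num_heads, head_size], k_cache with a last
# dimension of 8, v_cache as [num_blocks, num_kv_heads, head_size, block_size]).
# Requires a CUDA GPU and Triton >= 2.1.0.
if (__name__ == "__main__" and torch.cuda.is_available()
        and triton.__version__ >= "2.1.0"):
    torch.manual_seed(0)
    device, dtype = "cuda", torch.float16

    # Hypothetical sizes: one sequence with a 16-token cached prefix followed
    # by 12 new query tokens.
    num_heads, num_kv_heads, head_size = 8, 8, 64
    block_size, num_blocks = 16, 4
    ctx_len, query_len = 16, 12

    q = torch.randn(query_len, num_heads, head_size, device=device, dtype=dtype)
    k = torch.randn(query_len, num_kv_heads, head_size, device=device, dtype=dtype)
    v = torch.randn(query_len, num_kv_heads, head_size, device=device, dtype=dtype)
    o = torch.empty_like(q)

    k_cache = torch.randn(num_blocks, num_kv_heads, head_size // 8, block_size,
                          8, device=device, dtype=dtype)
    v_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          device=device, dtype=dtype)

    # Block table: logical block j of the prefix lives in physical block
    # b_loc[0, j] of the paged cache.
    b_loc = torch.tensor([[0, 1]], device=device, dtype=torch.int32)
    b_start_loc = torch.tensor([0], device=device, dtype=torch.int32)
    b_seq_len = torch.tensor([ctx_len + query_len], device=device,
                             dtype=torch.int32)
    b_ctx_len = torch.tensor([ctx_len], device=device, dtype=torch.int32)

    context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc,
                          b_seq_len, b_ctx_len, max_input_len=query_len)
    print("output shape:", o.shape)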