# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

from aphrodite.platforms import current_platform

if triton.__version__ >= "2.1.0":

    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
        SLIDING_WINDOW: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

        # start position inside of the query
        # generally, N goes over kv, while M goes over query_len
        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        # [N]; starts at 0
        offs_n = tl.arange(0, BLOCK_N)
        # [D]; starts at 0
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        # [M]; starts at current position in query
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

        # [M,D]
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
            0).to(tl.int1)  # [D]

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_query_len),
                    other=0.0)  # [M,D]

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                       dtype=tl.float32)  # [M,D]
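
        # Both loops below use the streaming ("online") softmax update:
        # m_i is the running row-wise max of the logits, l_i the running
        # softmax denominator, and acc the running weighted sum of V.
        # Each new logit block qk is folded in as
        #     m_new = max(m_i, max(qk, axis=1))
        #     l_new = exp(m_i - m_new) * l_i + sum(exp(qk - m_new), axis=1)
        #     acc   = acc * (l_i / l_new) * exp(m_i - m_new)
        #             + (exp(qk - m_new) / l_new) @ V
        # so acc is already softmax-normalized when the loops finish and no
        # final division by l_i is needed before the store.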

        # compute query against context (no causal mask here)
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)  # [N]
            # [D,N]
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
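            # bn maps each context token to its physical block id via the
            # block table B_Loc.  The key cache is laid out as
            # [num_blocks, num_kv_heads, head_size // x, block_size, x], so
            # element d of a cached key is addressed by depth d // x,
            # in-block slot (token position) % block_size, and lane d % x.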
            # [N,D]
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                        other=0.0)  # [D,N]

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale
            if SLIDING_WINDOW > 0:
                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
                # Q entries in sequence
                # (start_n + offs_n[None, :]) are the positions of
                # KV entries in sequence
                # So the condition makes sure each entry in Q only attends
                # to KV entries not more than SLIDING_WINDOW away.
                #
                # We can't use -inf here, because the
                # sliding window may lead to the entire row being masked.
                # This then makes m_ij contain -inf, which causes NaNs in
                # exp().
                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
                              -10000)

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)  # [M]
            p = tl.exp(qk - m_ij[:, None])  # [M,N]
            l_ij = tl.sum(p, 1)  # [M]
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)  # [M]
            alpha = tl.exp(m_i - m_i_new)  # [M]
            beta = tl.exp(m_ij - m_i_new)  # [M]
            l_i_new = alpha * l_i + beta * l_ij  # [M]

            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                        other=0.0)  # [N,D]

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        # block_mask is 0 when we're already past the current query length
        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

        # compute query against itself (with causal mask)
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            # apply causal mask
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))
            if SLIDING_WINDOW > 0:
                qk = tl.where(
                    offs_m[:, None] -
                    (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_query_len))
        return

    @triton.jit
    def _fwd_kernel_flash_attn_v2(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # acc /= l_i[:, None]
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
    ):
        # attn_bias[]
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                    other=0.0)

        # # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
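
        # ALiBi adds a per-head linear bias slope * (k_pos - q_pos) to each
        # logit.  Under the causal mask k_pos <= q_pos, so the bias is <= 0
        # and penalizes distant keys; entries where the bias would be
        # positive, or whose query row lies past the end of the sequence,
        # are forced to -inf instead.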
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                        other=0.0)  # [D,N]

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # # init debugger
        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
        # offset_db_k = tl.arange(0, BLOCK_N)
        # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) <
                         cur_batch_seq_len - cur_batch_ctx_len),
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) <
                         cur_batch_seq_len - cur_batch_ctx_len),
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
        return

    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None,
                              sliding_window=None):
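        """Prefix-prefill attention over new query tokens plus cached context.

        Expected layouts (inferred from the stride and shape arguments passed
        to the kernels below, so treat this as a reading aid rather than a
        spec): q, k, v and o are [num_tokens, num_heads(_kv), head_size];
        k_cache is [num_blocks, num_kv_heads, head_size // x, block_size, x];
        v_cache is [num_blocks, num_kv_heads, head_size, block_size]; b_loc is
        the per-sequence block table [batch, max_blocks_per_seq]; b_start_loc,
        b_seq_len and b_ctx_len are per-sequence offsets and lengths.
        """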
        cap = current_platform.get_device_capability()
        BLOCK = 128 if cap[0] >= 8 else 64
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        # round up Lk to a power of 2 - this is required for Triton block size
        Lk_padded = triton.next_power_of_2(Lk)

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
        num_queries_per_kv = q.shape[1] // k.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

        # 0 means "disable"
        if sliding_window is None or sliding_window <= 0:
            sliding_window = 0

        num_warps = 8
        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
                b_start_loc,
                b_seq_len,
                b_ctx_len,
                alibi_slopes,
                v_cache.shape[3],
                k_cache.shape[4],
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(
                    4
                ),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                v_cache.stride(
                    3),  # [num_blocks, num_kv_heads, head_size, block_size]
                num_queries_per_kv=num_queries_per_kv,
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_DMODEL_PADDED=Lk_padded,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            b_start_loc,
            b_seq_len,
            b_ctx_len,
            v_cache.shape[3],
            k_cache.shape[4],
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(
                4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(
                3),  # [num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SLIDING_WINDOW=sliding_window,
            num_warps=num_warps,
            num_stages=1,
        )
        return
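

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original kernels).  A minimal
# sketch of how context_attention_fwd might be invoked, assuming the tensor
# layouts described in its docstring; the sizes, dtypes and the zero-length
# prefix below are arbitrary choices for demonstration, not API requirements.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available() and triton.__version__ >= "2.1.0":
        torch.manual_seed(0)
        num_heads, head_size = 8, 64
        block_size, x = 16, 8
        query_len = 16   # tokens being prefilled for a single sequence
        num_blocks = 4   # size of the (unused, zero-context) paged KV cache

        q = torch.randn(query_len, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        o = torch.empty_like(q)

        # Paged cache in the layout expected by the kernels; with a zero
        # context length below it is never actually read.
        k_cache = torch.zeros(num_blocks, num_heads, head_size // x,
                              block_size, x,
                              dtype=torch.float16, device="cuda")
        v_cache = torch.zeros(num_blocks, num_heads, head_size, block_size,
                              dtype=torch.float16, device="cuda")
        b_loc = torch.zeros(1, num_blocks, dtype=torch.int32, device="cuda")

        b_start_loc = torch.tensor([0], dtype=torch.int32, device="cuda")
        b_ctx_len = torch.tensor([0], dtype=torch.int32, device="cuda")
        b_seq_len = b_ctx_len + query_len

        context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc,
                              b_start_loc, b_seq_len, b_ctx_len,
                              max_input_len=query_len)
        print("output shape:", o.shape)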