# prefix_prefill.py
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

from aphrodite.platforms import current_platform

if triton.__version__ >= "2.1.0":

    @triton.jit
    def _fwd_kernel(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        k_scale,
        v_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
        SLIDING_WINDOW: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

        # start position inside of the query
        # generally, N goes over kv, while M goes over query_len
        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        # [N]; starts at 0
        offs_n = tl.arange(0, BLOCK_N)
        # [D]; starts at 0
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        # [M]; starts at current position in query
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

        # [M,D]
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
            0).to(tl.int1)  # [D]

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_query_len),
                    other=0.0)  # [M,D]

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                       dtype=tl.float32)  # [M,D]

        # compute query against context (no causal mask here)
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)  # [N]
            # [D,N]
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            # [N,D]
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k_load = tl.load(K_cache + off_k,
                             mask=dim_mask[:, None] &
                             ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                             other=0.0)  # [D,N]

            if k_load.dtype.is_fp8():
                k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
            else:
                k = k_load

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale
            if SLIDING_WINDOW > 0:
                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
                # Q entries in sequence
                # (start_n + offs_n[None, :]) are the positions of
                # KV entries in sequence
                # So the condition makes sure each entry in Q only attends
                # to KV entries not more than SLIDING_WINDOW away.
                #
                # We can't use -inf here, because the
                # sliding window may lead to the entire row being masked.
                # This then makes m_ij contain -inf, which causes NaNs in
                # exp().
                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
                              -10000)

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)  # [M]
            p = tl.exp(qk - m_ij[:, None])  # [M,N]
            l_ij = tl.sum(p, 1)  # [M]
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)  # [M]
            alpha = tl.exp(m_i - m_i_new)  # [M]
            beta = tl.exp(m_ij - m_i_new)  # [M]
            l_i_new = alpha * l_i + beta * l_ij  # [M]

            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v_load = tl.load(V_cache + off_v,
                             mask=dim_mask[None, :] &
                             ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                             other=0.0)  # [N,D]
            if v_load.dtype.is_fp8():
                v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
            else:
                v = v_load

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        # block_mask is 0 when we're already past the current query length
        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

        # compute query against itself (with causal mask)
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            # apply causal mask
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))
            if SLIDING_WINDOW > 0:
                qk = tl.where(
                    offs_m[:, None] -
                    (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_query_len))
        return
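
    # ------------------------------------------------------------------
    # Illustrative reference only (an assumption added for clarity, not
    # used by any kernel in this file): the streaming "online softmax"
    # update that _fwd_kernel performs per BLOCK_N tile, written in plain
    # PyTorch. `q_blk` is an [M, D] query block and `kv_tiles` yields
    # ([N, D], [N, D]) key/value tiles; both names are hypothetical.
    # ------------------------------------------------------------------
    def _online_softmax_reference(q_blk, kv_tiles):
        m_i = torch.full((q_blk.shape[0], ), float("-inf"))
        l_i = torch.zeros(q_blk.shape[0])
        acc = torch.zeros(q_blk.shape[0], q_blk.shape[1])
        for k_tile, v_tile in kv_tiles:
            qk = q_blk @ k_tile.T  # [M, N] scores for this tile
            m_ij = qk.max(dim=1).values  # per-row max of the new tile
            p = torch.exp(qk - m_ij[:, None])
            l_ij = p.sum(dim=1)
            m_i_new = torch.maximum(m_i, m_ij)  # new running max
            alpha = torch.exp(m_i - m_i_new)  # rescales the old statistics
            beta = torch.exp(m_ij - m_i_new)  # rescales the new tile
            l_i_new = alpha * l_i + beta * l_ij
            # acc is kept normalized by the running denominator l_i, exactly
            # like the p_scale / acc_scale updates in _fwd_kernel above.
            acc = acc * (l_i / l_i_new * alpha)[:, None]
            acc = acc + (p * (beta / l_i_new)[:, None]) @ v_tile
            m_i, l_i = m_i_new, l_i_new
        return acc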

    @triton.jit
    def _fwd_kernel_flash_attn_v2(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(
            Q + off_q,
            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # acc /= l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @triton.jit
    def _fwd_kernel_alibi(
        Q,
        K,
        V,
        K_cache,
        V_cache,
        B_Loc,
        sm_scale,
        k_scale,
        v_scale,
        B_Start_Loc,
        B_Seqlen,
        B_Ctxlen,
        Alibi_slopes,
        block_size,
        x,
        Out,
        stride_b_loc_b,
        stride_b_loc_s,
        stride_qbs,
        stride_qh,
        stride_qd,
        stride_kbs,
        stride_kh,
        stride_kd,
        stride_vbs,
        stride_vh,
        stride_vd,
        stride_obs,
        stride_oh,
        stride_od,
        stride_k_cache_bs,
        stride_k_cache_h,
        stride_k_cache_d,
        stride_k_cache_bl,
        stride_k_cache_x,
        stride_v_cache_bs,
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,  # head size
        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
        BLOCK_N: tl.constexpr,
    ):
        # attn_bias[]
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_kv_head = cur_head // num_queries_per_kv

        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        dim_mask = tl.where(
            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)

        q = tl.load(Q + off_q,
                    mask=dim_mask[None, :] &
                    (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                    other=0.0)

        # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k_load = tl.load(K_cache + off_k,
                             mask=dim_mask[:, None] &
                             ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                             other=0.0)  # [D,N]

            if k_load.dtype.is_fp8():
                k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
            else:
                k = k_load

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v_load = tl.load(V_cache + off_v,
                             mask=dim_mask[None, :] &
                             ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                             other=0.0)
            if v_load.dtype.is_fp8():
                v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
            else:
                v = v_load

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # # init debugger
        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
        # offset_db_k = tl.arange(0, BLOCK_N)
        # calc q[BLOCK_M, BLOCK_DMODEL] mul k[prefix_len: , BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=dim_mask[:, None] &
                        ((start_n + offs_n[None, :]) <
                         cur_batch_seq_len - cur_batch_ctx_len),
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=dim_mask[None, :] &
                        ((start_n + offs_n[:, None]) <
                         cur_batch_seq_len - cur_batch_ctx_len),
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
        return
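
    # ------------------------------------------------------------------
    # Illustrative sketch only (an added assumption, not part of the
    # original kernels): the ALiBi bias that _fwd_kernel_alibi adds to
    # each score tile, computed densely in PyTorch for a single head.
    # `q_pos` / `k_pos` are absolute token positions; the names are
    # hypothetical.
    # ------------------------------------------------------------------
    def _alibi_bias_reference(alibi_slope, q_pos, k_pos):
        # bias = (k_pos - q_pos) * slope, kept only where the key is at or
        # before the query; future keys get -inf, matching the tl.where in
        # the kernel above (modulo the extra seq-len bound it also applies).
        bias = (k_pos[None, :] - q_pos[:, None]).float() * alibi_slope
        neg_inf = torch.full_like(bias, float("-inf"))
        return torch.where(bias <= 0, bias, neg_inf)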

    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
                              kv_cache_dtype: str,
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              k_scale: float = 1.0,
                              v_scale: float = 1.0,
                              alibi_slopes=None,
                              sliding_window=None):
        cap = current_platform.get_device_capability()
        BLOCK = 128 if cap[0] >= 8 else 64
        NUM_WARPS = 8

        if q.dtype is torch.float32:
            BLOCK = BLOCK // 2

        # Conversion of FP8 Tensor from uint8 storage to
        # appropriate torch.dtype for interpretation by Triton
        if "fp8" in kv_cache_dtype:
            assert (k_cache.dtype == torch.uint8)
            assert (v_cache.dtype == torch.uint8)

            if kv_cache_dtype in ("fp8", "fp8_e4m3"):
                target_dtype = torch.float8_e4m3fn
            elif kv_cache_dtype == "fp8_e5m2":
                target_dtype = torch.float8_e5m2
            else:
                raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)

            k_cache = k_cache.view(target_dtype)
            v_cache = v_cache.view(target_dtype)

        if (k_cache.dtype == torch.uint8
                or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
            raise ValueError("kv_cache_dtype='auto' unsupported for "
                             "FP8 KV Cache prefill kernel")

        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        # round up Lk to a power of 2 - this is required for Triton block size
        Lk_padded = triton.next_power_of_2(Lk)

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
        num_queries_per_kv = q.shape[1] // k.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

        # 0 means "disable"
        if sliding_window is None or sliding_window <= 0:
            sliding_window = 0

        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
                q,
                k,
                v,
                k_cache,
                v_cache,
                b_loc,
                sm_scale,
                k_scale,
                v_scale,
                b_start_loc,
                b_seq_len,
                b_ctx_len,
                alibi_slopes,
                v_cache.shape[3],
                k_cache.shape[4],
                o,
                b_loc.stride(0),
                b_loc.stride(1),
                q.stride(0),
                q.stride(1),
                q.stride(2),
                k.stride(0),
                k.stride(1),
                k.stride(2),
                v.stride(0),
                v.stride(1),
                v.stride(2),
                o.stride(0),
                o.stride(1),
                o.stride(2),
                k_cache.stride(0),
                k_cache.stride(1),
                k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0),
                v_cache.stride(1),
                v_cache.stride(2),
                v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
                num_queries_per_kv=num_queries_per_kv,
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_DMODEL_PADDED=Lk_padded,
                BLOCK_N=BLOCK,
                num_warps=NUM_WARPS,
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q,
            k,
            v,
            k_cache,
            v_cache,
            b_loc,
            sm_scale,
            k_scale,
            v_scale,
            b_start_loc,
            b_seq_len,
            b_ctx_len,
            v_cache.shape[3],
            k_cache.shape[4],
            o,
            b_loc.stride(0),
            b_loc.stride(1),
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            k_cache.stride(0),
            k_cache.stride(1),
            k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(4),  # [num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0),
            v_cache.stride(1),
            v_cache.stride(2),
            v_cache.stride(3),  # [num_blocks, num_kv_heads, head_size, block_size]
            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_DMODEL_PADDED=Lk_padded,
            BLOCK_N=BLOCK,
            SLIDING_WINDOW=sliding_window,
            num_warps=NUM_WARPS,
            num_stages=1,
        )
        return
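
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative assumption, not part of the original
# module): launches context_attention_fwd for a tiny batch with no cached
# prefix (b_ctx_len = 0), so the paged KV cache only has to match the layout
# the kernels expect. Requires a CUDA device and Triton >= 2.1.0; all sizes
# below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device, dtype = "cuda", torch.float16

    batch, num_heads, num_kv_heads, head_size = 2, 8, 8, 64
    block_size, x, num_blocks = 16, 8, 64
    query_len = 32
    total_tokens = batch * query_len

    q = torch.randn(total_tokens, num_heads, head_size, device=device, dtype=dtype)
    k = torch.randn(total_tokens, num_kv_heads, head_size, device=device, dtype=dtype)
    v = torch.randn_like(k)
    o = torch.empty_like(q)

    # Paged KV cache layouts expected by the kernels:
    #   k_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
    #   v_cache: [num_blocks, num_kv_heads, head_size, block_size]
    k_cache = torch.zeros(num_blocks, num_kv_heads, head_size // x, block_size, x,
                          device=device, dtype=dtype)
    v_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size,
                          device=device, dtype=dtype)

    b_loc = torch.zeros(batch, 8, device=device, dtype=torch.int32)  # block table
    b_start_loc = torch.arange(0, total_tokens, query_len, device=device,
                               dtype=torch.int32)
    b_seq_len = torch.full((batch, ), query_len, device=device, dtype=torch.int32)
    b_ctx_len = torch.zeros(batch, device=device, dtype=torch.int32)  # no prefix

    context_attention_fwd(q, k, v, o, "auto", k_cache, v_cache, b_loc,
                          b_start_loc, b_seq_len, b_ctx_len,
                          max_input_len=query_len)
    print(o.shape)  # torch.Size([64, 8, 64])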