fwd_decode.py

import torch
import triton
import triton.language as tl
from .utils import _strides, get_padded_headsize

@triton.jit
def _fwd_kernel_splitK(
    Q,
    K,
    V,
    sm_scale,
    Out_splitK,  # [B, H, split_k, Mq, K]
    Metadata,  # [B, H, 2, split_k, M_ceil] contains [mi, li]
    K_new,
    V_new,
    Cache_seqlens,
    Cache_batch_idx,
    Alibi_slopes,
    stride_qz,
    stride_qm,
    stride_qg,
    stride_qh,
    stride_qd,
    stride_kz,
    stride_kn,
    stride_kg,
    stride_kh,
    stride_kd,
    stride_vz,
    stride_vn,
    stride_vg,
    stride_vh,
    stride_vd,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_d,
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,
    stride_kn_z,
    stride_kn_n,
    stride_kn_g,
    stride_kn_h,
    stride_kn_d,
    stride_vn_z,
    stride_vn_n,
    stride_vn_g,
    stride_vn_h,
    stride_vn_d,
    stride_az,
    stride_ah,
    Z,
    N_CTX_Q,
    N_CTX_K,
    N_CTX_NEW,
    BLOCK_N_PER_SPLIT,
    H_q: tl.constexpr,
    H_kv: tl.constexpr,
    G_q: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    USE_CACHE_SEQLENs: tl.constexpr,
    USE_CACHE_BATCH_IDX: tl.constexpr,
    NEW_KV: tl.constexpr,
    IS_GQA: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_ALIBI: tl.constexpr,
):
    # Padding: the head dim is padded up to the next supported block size
    PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
    if PADDED_HEAD:
        d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL

    start_m = tl.program_id(0)
    off_zhg = tl.program_id(1)
    off_z = off_zhg // (H_q * G_q)
    off_h_q = (off_zhg // G_q) % H_q
    off_g_q = off_zhg % G_q
    splitk_idx = tl.program_id(2)

    # pick batch index
    if USE_CACHE_BATCH_IDX:
        cache_batch_idx = tl.load(Cache_batch_idx + off_z)
    else:
        cache_batch_idx = off_z

    # Load ALiBi slope if enabled
    if USE_ALIBI:
        a_offset = off_z * stride_az + off_h_q * stride_ah
        alibi_slope = tl.load(Alibi_slopes + a_offset)
    else:
        alibi_slope = None

    lo = splitk_idx * BLOCK_N_PER_SPLIT
    if USE_CACHE_SEQLENs:
        cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
        if NEW_KV:
            kv_len = cache_seqlen_last_idx + N_CTX_NEW
        else:
            kv_len = cache_seqlen_last_idx
    else:
        kv_len = N_CTX_K
    hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)

    HEAD_RATIO: tl.constexpr = H_q // H_kv
    if IS_GQA:
        k_head_idx = off_h_q // HEAD_RATIO
        v_head_idx = k_head_idx
    else:
        k_head_idx = off_h_q
        v_head_idx = off_h_q

    # calculate base offsets
    k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
    v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg

    # Copy new Keys and Values into the cache
    if NEW_KV:
        knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

        # Determine the starting position for new data in the cache
        if USE_CACHE_SEQLENs:
            start_idx = tl.load(Cache_seqlens + off_z)
        else:
            start_idx = N_CTX_K - N_CTX_NEW

        # Copy new Keys
        for i in range(0, N_CTX_NEW, BLOCK_N):
            # Load from K_new
            k_new_block = tl.load(
                knew_base +
                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
                (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
                mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
                other=0,
            )

            # Store to K
            tl.store(
                k_base +
                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
                (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
                k_new_block,
                mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
            )

        # Copy new Values
        vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
        for i in range(0, N_CTX_NEW, BLOCK_N):
            # Load from V_new
            v_new_block = tl.load(
                vnew_base +
                (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
                tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
                mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
                other=0,
            )

            # Store to V
            tl.store(
                v_base +
                (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
                tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
                v_new_block,
                mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
            )

    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg,
        shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qd),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )

    K_block_ptr = tl.make_block_ptr(
        base=k_base,
        shape=(ACTUAL_BLOCK_DMODEL, hi),
        strides=(stride_kd, stride_kn),
        offsets=(0, lo),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=v_base,
        shape=(hi, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vn, stride_vd),
        offsets=(lo, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )

    K_scale_shift_block_ptr = None
    V_scale_shift_block_ptr = None

    # initialize running max (m_i), running sum (l_i) and accumulator
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)  # noqa: F821

    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504

    # load q: it will stay in SRAM throughout
    q = tl.load(  # noqa: F821
        tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
    q = (q * qk_scale).to(q.dtype)
    if PADDED_HEAD:
        q = tl.where(d_mask[None, :], q, 0.0)

    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        k, v = load_k_v_group(
            K_block_ptr,
            V_block_ptr,
            K_scale_shift_block_ptr,
            V_scale_shift_block_ptr,
            BOUNDS_CHECKS_N,
            1,
            BLOCK_DMODEL,
            ACTUAL_BLOCK_DMODEL,
            Q.dtype.element_ty,
            0,
        )
        if PADDED_HEAD:
            k = tl.where(d_mask[:, None], k, 0.0)
            v = tl.where(d_mask[None, :], v, 0.0)

        # -- compute qk --
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)  # noqa: F821

        if USE_ALIBI:
            row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            col_idx = start_n + tl.arange(0, BLOCK_N)

            # Compute relative positions
            relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
            relative_pos = tl.abs(relative_pos)

            # Compute ALiBi bias (scaled by log_2(e) to match the exp2 softmax)
            alibi_bias = -1 * alibi_slope * relative_pos
            qk += (alibi_bias * 1.44269504)

        # Apply causal mask if IS_CAUSAL is True
        if IS_CAUSAL:
            row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            col_idx = start_n + tl.arange(0, BLOCK_N)

            # create a N_CTX_Q x kv_len causal mask
            col_offset = N_CTX_Q - kv_len
            causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])

            # Apply the mask
            qk = tl.where(causal_mask, qk, float("-inf"))

        # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
        if BOUNDS_CHECKS_N:
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))

        # -- compute scaling constant --
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        if IS_CAUSAL:
            alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
        else:
            alpha = tl.math.exp2(m_i - m_i_new)

        # guard against NaN from subtracting -inf from -inf when a row is fully masked
        if IS_CAUSAL:
            qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
        else:
            qk = qk - m_i_new[:, None]
        p = tl.math.exp2(qk)

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        p = p.to(Q.dtype.element_ty)

        # -- scale and update acc --
        acc *= alpha[:, None]
        acc += tl.dot(p.to(v.dtype), v)

        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        shape=(N_CTX_Q, BLOCK_DMODEL),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(
        tl.advance(O_block_ptr, (0, 0)),
        acc,
        boundary_check=(0, ),
    )

    # Write metadata for split-K reduction
    Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M +
                    tl.arange(0, BLOCK_M))
    tl.store(Metadata_ptr, m_i)
    tl.store(Metadata_ptr + stride_m2, l_i)

@triton.jit
def load_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    dtype: tl.constexpr,
    group_id: tl.constexpr,
):
    # Load K/V for a given block
    # Advance to the current quantization group
    K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0))
    V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))

    # -- load k, v --
    k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
    v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())

    return k, v

@triton.jit
def cast_uint32_to_half2(scale_shift):
    # Extract two float16 values packed into one int32
    scale = scale_shift & 0xFFFF
    shift = scale_shift >> 16
    scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
    shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
    return scale, shift

@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr = 8,
):
    # PACKED_PER_VAL is the number of values packed into
    # each element of x_. For example, for int4 quantization
    # and x_ of type int32, PACKED_PER_VAL is 8.
    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    offsets = tl.arange(0, PACKED_PER_VAL) * 4
    quant_offset = (x_[:, None, :] >> offsets[None, :, None])  # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)

    quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
    # Trick - instead of converting int4 to float16 we view it as float16
    # and then multiply by 32768 * 512 == 2**24
    quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
    quant_offset = (quant_offset * 32768.0).to(tl.float16)
    scale_512 = scale * 512

    dequant = quant_offset * scale_512 + shift
    return dequant
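

# Hedged numeric sketch of the bitcast trick above (illustrative only, not part
# of the original kernels): a 4-bit value v placed in the low bits of a float16
# is the subnormal v * 2**-24, so multiplying by 32768 * 512 == 2**24 recovers v,
# which is why `dequantize` folds the extra 512 factor into the scale.
def _example_int4_bitcast_trick(v: int = 13) -> float:
    x = torch.tensor([v], dtype=torch.int16).view(torch.float16)  # bit pattern 0x000v
    return float(x * 32768.0 * 512.0)  # == float(v)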

@triton.jit
def _splitK_reduce(
    Out_splitK,  # [B, H, split_k, Mq, K]
    Metadata,  # [B, H, 2, split_k, M_ceil] contains [mi, li]
    Out,  # [B, H, M, K]
    LSE,  # [B, H, M]
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,
    stride_oz,
    stride_oh,
    stride_og,
    stride_om,
    stride_ok,
    stride_lse_zhg,
    stride_lse_m,
    M_ceil: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
    split_k: tl.constexpr,
    splitK_pow2: tl.constexpr,
    use_mask: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    off_zhg = tl.program_id(0)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    off_m = tl.program_id(1)
    off_k = tl.program_id(2)

    # read chunk
    spk_idx = tl.arange(0, splitK_pow2)
    kidx = tl.arange(0, BLOCK_SIZE)

    Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm)

    o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE +
             stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k)

    # read max values of each splitK
    if use_mask:
        spk_mask = spk_idx < split_k
        l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
        l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
        acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
    else:
        l_m = tl.load(Metadata_ptr)
        l_sum = tl.load(Metadata_ptr + stride_m2)
        acc = tl.load(o_ptr)

    g_m = tl.max(l_m, axis=0)

    if IS_CAUSAL:
        l_m_offset = l_m - g_m
        alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
    else:
        alpha = tl.math.exp2(l_m - g_m)

    # rescale per-split sums and accumulators to the global max
    l_sum *= alpha
    g_sum = tl.sum(l_sum, axis=0)
    acc = acc * alpha[:, None]

    if IS_CAUSAL:
        # Avoid division by zero
        g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
        acc_out = tl.sum(acc, axis=0) / g_sum_safe
    else:
        acc_out = tl.sum(acc, axis=0) / g_sum

    # Store output
    Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m +
               off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE))
    tl.store(Out_ptr, acc_out)

    # Store lse
    l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
    if IS_CAUSAL:
        lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
        tl.store(l_ptrs, lse)
    else:
        tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504)

def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    # Scale and shift are such that quantization linearly maps
    # int4 values range [0..15] to input values range min(k)..max(k)
    # individually for every row
    k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
    max_vals = torch.max(k, dim=-1, keepdim=True).values
    min_vals = torch.min(k, dim=-1, keepdim=True).values
    scale_k: torch.Tensor = (max_vals - min_vals) / 15

    shift_k = torch.min(k, dim=-1, keepdim=True).values
    scale_k = scale_k.to(torch.float16)
    shift_k = shift_k.to(torch.float16)

    in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
    in_bytes = in_bytes.to(torch.uint8)
    in_int4 = in_bytes & 0xF
    in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
    scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
    k_quant = torch.concat(
        [
            scale_shift.flatten(start_dim=-2),
            in_int4_packed.flatten(start_dim=-2),
        ],
        dim=-1,
    ).view(torch.int16)
    return k_quant

def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    k_i16 = quant_k.view(torch.int16)
    k_ui8 = k_i16.view(torch.uint8)

    ss_size = num_groups * 4
    scale_shift_ui8 = k_ui8[..., 0:ss_size]
    scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4)
    scale = scale_shift_ui8[..., 0:2].view(torch.float16)
    shift = scale_shift_ui8[..., 2:4].view(torch.float16)

    kv_ui8 = k_ui8[..., ss_size:]
    k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1)
    k1_i4 = k_ui8 & 0xF
    k2_i4 = (k_ui8 & 0xF0) >> 4
    k_shape = k1_i4.shape
    k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
    k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)

    out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device)
    out[..., ::2] = k1_f16
    out[..., 1::2] = k2_f16
    out = out.reshape(*k_shape[:-2], -1)

    return out
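

# Hedged round-trip sketch for the two helpers above (illustrative only; the
# shapes and the error bound are assumptions, not taken from this file, and a
# recent PyTorch with fp16 support on the chosen device is assumed). int4
# quantization is lossy, so values should only agree to within roughly one
# quantization step, i.e. about (max - min) / 15 per row.
def _example_int4_roundtrip() -> torch.Tensor:
    k = torch.randn(1, 2, 4, 64, dtype=torch.float16)
    k_quant = quantize_kv_int4(k, num_groups=1)
    k_dequant = dequantize_kv_fp16(k_quant, num_groups=1)
    assert k_dequant.shape == k.shape
    return (k - k_dequant).abs().max()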

def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
    """Heuristic for the number of splits"""
    bh = max(B * H, 1)  # NOTE: Handle B*H=0 case
    split_k = max(Mk, 1024) // bh
    max_chunk_size = 64
    while split_k > 0 and Mk / split_k < max_chunk_size:
        split_k = split_k // 2
    while B * H * G * split_k >= 1024:
        split_k = split_k // 2
    split_k = min(split_k, 512)
    split_k = max(split_k, 1)
    return split_k
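

# Hedged worked example of the heuristic above (hypothetical shapes, not from
# this file): for B=1, G=1, H=32, Mk=8192 the initial guess max(Mk, 1024) // (B*H)
# is 256; the first loop halves it to 128 so each split covers at least 64 keys,
# and the second loop halves it to 16 so B*H*G*split_k stays below 1024.
def _example_get_split_k() -> int:
    return get_split_k(B=1, G=1, H=32, Mk=8192)  # == 16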

def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens,
                                          cache_batch_idx, new_kv, k_new, v_new):
    # kernel config
    BLOCK_M = 16
    BLOCK_N = 64
    SPLIT_K = None
    NUM_QUANT_GROUPS = 1

    # kernel expects "bsghd"
    original_layout = layout
    if layout == "bshd":
        q = q.unsqueeze(2)
        k = k.unsqueeze(2)
        v = v.unsqueeze(2)
        if new_kv:
            k_new = k_new.unsqueeze(2)
            v_new = v_new.unsqueeze(2)
        layout = "bsghd"
    elif layout == "bhsd":
        q = q.permute(0, 2, 1, 3).unsqueeze(2)
        k = k.permute(0, 2, 1, 3).unsqueeze(2)
        v = v.permute(0, 2, 1, 3).unsqueeze(2)
        if new_kv:
            k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
            v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
        layout = "bsghd"
    elif layout == "bsghd":
        pass
    elif layout is None:
        raise ValueError("Layout not given")
    assert layout == "bsghd"

    # get dims
    batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
    _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape
    _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape

    assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"

    # get padded size
    dim_padded = get_padded_headsize(dim_k)

    # Handle MQA/GQA case
    if heads_per_group_q > heads_per_group_k:
        is_gqa = True
    elif heads_per_group_q < heads_per_group_k:
        raise ValueError("heads_per_group_q < heads_per_group_k")
    else:
        is_gqa = False

    assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"

    if SPLIT_K is not None:
        split_k = SPLIT_K
    else:
        # Use heuristics
        split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k)  # NOTE: should the split think about seqlens?

    seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
    out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded],
                             dtype=torch.float32, device=q.device)
    metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil],
                           dtype=torch.float32, device=q.device)
    lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32)
    grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k)

    num_warps = 1
    split_size = (seqlen_k + split_k - 1) // split_k
    use_cache_seqlens = cache_seqlens is not None

    # TODO: enable quantization
    _fwd_kernel_splitK[grid](
        Q=q,
        K=k,
        V=v,
        sm_scale=sm_scale,
        Out_splitK=out_splitk,
        Metadata=metadata,
        K_new=k_new,
        V_new=v_new,
        Cache_seqlens=cache_seqlens,
        Cache_batch_idx=cache_batch_idx,
        Alibi_slopes=alibi_slopes,
        **_strides(q, "qz", "qm", "qg", "qh", "qd"),
        **_strides(k, "kz", "kn", "kg", "kh", "kd"),
        **_strides(v, "vz", "vn", "vg", "vh", "vd"),
        **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
        **_strides(metadata, "mzhg", "m2", "ms", "mm"),
        **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
        **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
        **_strides(alibi_slopes, "az", "ah"),
        Z=batch_size,
        H_q=heads_per_group_q,
        H_kv=heads_per_group_k,
        G_q=n_group_q,
        N_CTX_Q=seqlen_q,
        N_CTX_K=seqlen_k,
        N_CTX_NEW=k_new.shape[1] if new_kv else None,
        BLOCK_N_PER_SPLIT=split_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=dim_padded,
        ACTUAL_BLOCK_DMODEL=dim_k,
        BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
        USE_CACHE_SEQLENs=use_cache_seqlens,
        USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
        NEW_KV=new_kv,
        IS_GQA=is_gqa,
        IS_CAUSAL=causal,
        USE_ALIBI=False if alibi_slopes is None else True,
        num_warps=num_warps,
        num_stages=1,
    )

    out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)

    # Merge the split-K partial results together
    splitK_pow2 = triton.next_power_of_2(split_k)
    use_mask = splitK_pow2 > split_k
    if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512:
        k_block_num = 1
    else:
        k_block_num = 2
    assert dim_padded % k_block_num == 0
    k_block_size = dim_padded // k_block_num
    grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num)

    _splitK_reduce[grid](
        out_splitk,
        metadata,
        out,
        lse,
        **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
        **_strides(metadata, "mzhg", "m2", "ms", "mm"),
        **_strides(out, "oz", "om", "og", "oh", "ok"),
        **_strides(lse, "lse_zhg", "lse_m"),
        M_ceil=seqlen_q_ceil,
        BLOCK_SIZE=k_block_size,
        G=n_group_q,
        H=heads_per_group_q,
        # TODO: Tune num_warps
        split_k=split_k,
        splitK_pow2=splitK_pow2,
        use_mask=use_mask,
        IS_CAUSAL=causal,
        num_warps=4)

    lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
    if q.ndim == 4:
        # BMGHK -> BMHK
        assert n_group_q == 1
        out = out[:, :, 0]
        lse = lse[:, 0]
    if seqlen_k == 0:
        out.zero_()
    out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous()

    # output is batch_size, heads_per_group_q * n_group_q, seqlen_q, dim_q
    if original_layout == "bshd":
        # out = out.transpose(1, 2).contiguous()  # this screws up heads and data.
        # the data is laid out properly. Just need to reshape dims
        out = out.reshape(batch_size, seqlen_q, -1, dim_padded)

    return out.narrow(-1, 0, dim_k), lse
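

# Minimal usage sketch for attention_decode_forward_triton_impl (illustrative
# only; the shapes, sm_scale and the "bshd" layout are assumptions, not taken
# from this file). It also assumes the `_strides` helper from .utils tolerates
# the None tensors passed for the unused new-KV and ALiBi arguments, and that a
# GPU supported by Triton is available.
def _example_decode_forward():
    B, H, D = 2, 8, 128
    seqlen_q, cache_len = 1, 256
    q = torch.randn(B, seqlen_q, H, D, dtype=torch.float16, device="cuda")
    k = torch.randn(B, cache_len, H, D, dtype=torch.float16, device="cuda")
    v = torch.randn(B, cache_len, H, D, dtype=torch.float16, device="cuda")
    out, lse = attention_decode_forward_triton_impl(
        q, k, v,
        sm_scale=D**-0.5,
        causal=True,
        alibi_slopes=None,
        layout="bshd",
        cache_seqlens=None,
        cache_batch_idx=None,
        new_kv=False,
        k_new=None,
        v_new=None,
    )
    return out.shape, lse.shape  # out: (B, seqlen_q, H, D)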