awq_triton.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import torch
  2. import triton
  3. import triton.language as tl
  4. AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
  5. @triton.jit
  6. def awq_dequantize_kernel(
  7. qweight_ptr, # quantized matrix
  8. scales_ptr, # scales, per group
  9. zeros_ptr, # zeros, per group
  10. group_size, # Should always be one of the supported group sizes
  11. result_ptr, # Output matrix
  12. num_cols, # input num cols in qweight
  13. num_rows, # input num rows in qweight
  14. BLOCK_SIZE_X: tl.constexpr,
  15. BLOCK_SIZE_Y: tl.constexpr):
  16. # Setup the pids.
  17. pid_x = tl.program_id(axis=0)
  18. pid_y = tl.program_id(axis=1)
  19. # Compute offsets and masks for qweight_ptr.
  20. offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
  21. offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
  22. offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
  23. masks_y = offsets_y < num_rows
  24. masks_x = offsets_x < num_cols
  25. masks = masks_y[:, None] & masks_x[None, :]
  26. # Compute offsets and masks for result output ptr.
  27. result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
  28. result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
  29. 0, BLOCK_SIZE_X * 8)
  30. result_offsets = (8 * num_cols * result_offsets_y[:, None] +
  31. result_offsets_x[None, :])
  32. result_masks_y = result_offsets_y < num_rows
  33. result_masks_x = result_offsets_x < num_cols * 8
  34. result_masks = result_masks_y[:, None] & result_masks_x[None, :]
  35. # Load the weights.
  36. iweights = tl.load(qweight_ptr + offsets, masks)
  37. iweights = tl.interleave(iweights, iweights)
  38. iweights = tl.interleave(iweights, iweights)
  39. iweights = tl.interleave(iweights, iweights)
  40. # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
  41. # that will map given indices to the correct order.
  42. reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
  43. tl.arange(0, 4)[:, None]).reshape(8)
  44. # Use this to compute a set of shifts that can be used to unpack and
  45. # reorder the values in iweights and zeros.
  46. shifts = reverse_awq_order_tensor * 4
  47. shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
  48. shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
  49. # Unpack and reorder: shift out the correct 4-bit value and mask.
  50. iweights = (iweights >> shifts) & 0xF
  51. # Compute zero offsets and masks.
  52. zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
  53. zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
  54. zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
  55. zero_masks_y = zero_offsets_y < num_rows // group_size
  56. zero_masks_x = zero_offsets_x < num_cols
  57. zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
  58. # Load the zeros.
  59. zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
  60. zeros = tl.interleave(zeros, zeros)
  61. zeros = tl.interleave(zeros, zeros)
  62. zeros = tl.interleave(zeros, zeros)
  63. zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
  64. # Unpack and reorder: shift out the correct 4-bit value and mask.
  65. zeros = (zeros >> shifts) & 0xF
  66. # Compute scale offsets and masks.
  67. scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
  68. scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
  69. tl.arange(0, BLOCK_SIZE_X * 8))
  70. scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
  71. scale_offsets_x[None, :])
  72. scale_masks_y = scale_offsets_y < num_rows // group_size
  73. scale_masks_x = scale_offsets_x < num_cols * 8
  74. scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
  75. # Load the scales.
  76. scales = tl.load(scales_ptr + scale_offsets, scale_masks)
  77. scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
  78. # Dequantize.
  79. iweights = (iweights - zeros) * scales
  80. iweights = iweights.to(result_ptr.type.element_ty)
  81. # Finally, store.
  82. tl.store(result_ptr + result_offsets, iweights, result_masks)
  83. @triton.jit
  84. def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
  85. group_size, BLOCK_SIZE_M: tl.constexpr,
  86. BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  87. SPLIT_K: tl.constexpr):
  88. pid = tl.program_id(axis=0)
  89. pid_z = tl.program_id(1)
  90. # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
  91. # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
  92. num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
  93. pid_m = pid // num_pid_n
  94. pid_n = pid % num_pid_n
  95. accumulator_dtype = c_ptr.type.element_ty
  96. # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
  97. # accumulator = tl.arange(0, BLOCK_SIZE_N)
  98. # accumulator = tl.broadcast_to(accumulator[None, :],
  99. # (BLOCK_SIZE_M, BLOCK_SIZE_N))
  100. # accumulator = accumulator & 0x0
  101. # accumulator = accumulator.to(accumulator_dtype)
  102. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
  103. dtype=accumulator_dtype)
  104. # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
  105. # that will map given indices to the correct order.
  106. reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
  107. tl.arange(0, 4)[:, None]).reshape(8)
  108. # Create the necessary shifts to use to unpack.
  109. shifts = reverse_awq_order_tensor * 4
  110. shifts = tl.broadcast_to(shifts[None, :],
  111. (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
  112. shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
  113. # Offsets and masks.
  114. offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  115. masks_am = offsets_am < M
  116. offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
  117. masks_bn = offsets_bn < N // 8
  118. offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
  119. masks_zn = offsets_zn < N // 8
  120. offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  121. masks_sn = offsets_sn < N
  122. offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
  123. offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
  124. offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
  125. a_ptrs = a_ptr + offsets_a
  126. b_ptrs = b_ptr + offsets_b
  127. # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
  128. # block_offset = BLOCK_SIZE_K * SPLIT_K
  129. # for k in range(0, (K + block_offset - 1) // (block_offset)):
  130. for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
  131. masks_k = offsets_k < K
  132. masks_a = masks_am[:, None] & masks_k[None, :]
  133. a = tl.load(a_ptrs, mask=masks_a)
  134. masks_b = masks_k[:, None] & masks_bn[None, :]
  135. b = tl.load(b_ptrs, mask=masks_b)
  136. b = tl.interleave(b, b)
  137. b = tl.interleave(b, b)
  138. b = tl.interleave(b, b)
  139. # Dequantize b.
  140. offsets_szk = (
  141. (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
  142. tl.arange(0, 1))
  143. offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
  144. masks_zk = offsets_szk < K // group_size
  145. masks_z = masks_zk[:, None] & masks_zn[None, :]
  146. zeros_ptrs = zeros_ptr + offsets_z
  147. zeros = tl.load(zeros_ptrs, mask=masks_z)
  148. zeros = tl.interleave(zeros, zeros)
  149. zeros = tl.interleave(zeros, zeros)
  150. zeros = tl.interleave(zeros, zeros)
  151. zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
  152. offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
  153. masks_sk = offsets_szk < K // group_size
  154. masks_s = masks_sk[:, None] & masks_sn[None, :]
  155. scales_ptrs = scales_ptr + offsets_s
  156. scales = tl.load(scales_ptrs, mask=masks_s)
  157. scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
  158. b = (b >> shifts) & 0xF
  159. zeros = (zeros >> shifts) & 0xF
  160. b = (b - zeros) * scales
  161. b = b.to(c_ptr.type.element_ty)
  162. # Accumulate results.
  163. accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
  164. offsets_k += BLOCK_SIZE_K * SPLIT_K
  165. a_ptrs += BLOCK_SIZE_K * SPLIT_K
  166. b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
  167. c = accumulator.to(c_ptr.type.element_ty)
  168. offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  169. offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  170. c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
  171. c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
  172. if SPLIT_K == 1:
  173. tl.store(c_ptrs, c, mask=c_mask)
  174. else:
  175. tl.atomic_add(c_ptrs, c, mask=c_mask)
  176. # qweights - [K , M // 8], int32
  177. # scales - [K // G, M ], float16
  178. # zeros - [K // G, M // 8], int32
  179. def awq_dequantize_triton(qweight: torch.Tensor,
  180. scales: torch.Tensor,
  181. zeros: torch.Tensor,
  182. block_size_x: int = 32,
  183. block_size_y: int = 32) -> torch.Tensor:
  184. K = qweight.shape[0]
  185. M = scales.shape[1]
  186. group_size = qweight.shape[0] // scales.shape[0]
  187. assert K > 0 and M > 0
  188. assert scales.shape[0] == K // group_size and scales.shape[1] == M
  189. assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
  190. assert group_size <= K
  191. assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
  192. # Result tensor:
  193. # number of rows = same as input tensor
  194. # number of cols = 8 x input tensor num cols
  195. result = torch.empty(qweight.shape[0],
  196. qweight.shape[1] * 8,
  197. device=qweight.device,
  198. dtype=scales.dtype)
  199. Y = qweight.shape[0] # num rows
  200. X = qweight.shape[1] # num cols
  201. grid = lambda META: (
  202. triton.cdiv(X, META['BLOCK_SIZE_X']),
  203. triton.cdiv(Y, META['BLOCK_SIZE_Y']),
  204. )
  205. awq_dequantize_kernel[grid](qweight,
  206. scales,
  207. zeros,
  208. group_size,
  209. result,
  210. X,
  211. Y,
  212. BLOCK_SIZE_X=block_size_x,
  213. BLOCK_SIZE_Y=block_size_y)
  214. return result
  215. # input - [M, K]
  216. # qweight - [K, N // 8]
  217. # qzeros - [K // G, N // 8]
  218. # scales - [K // G, N]
  219. # split_k_iters - parallelism along K-dimension, int, power of 2.
  220. def awq_gemm_triton(input: torch.Tensor,
  221. qweight: torch.Tensor,
  222. scales: torch.Tensor,
  223. qzeros: torch.Tensor,
  224. split_k_iters: int,
  225. block_size_m: int = 32,
  226. block_size_n: int = 32,
  227. block_size_k: int = 32) -> torch.Tensor:
  228. M, K = input.shape
  229. N = qweight.shape[1] * 8
  230. group_size = qweight.shape[0] // qzeros.shape[0]
  231. assert N > 0 and K > 0 and M > 0
  232. assert qweight.shape[0] == K and qweight.shape[1] == N // 8
  233. assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
  234. assert scales.shape[0] == K // group_size and scales.shape[1] == N
  235. assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
  236. assert split_k_iters <= 32
  237. assert group_size <= K
  238. assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
  239. grid = lambda META: (
  240. triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
  241. N, META['BLOCK_SIZE_N']),
  242. split_k_iters,
  243. )
  244. result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
  245. # A = input, B = qweight, C = result
  246. # A = M x K, B = K x N, C = M x N
  247. awq_gemm_kernel[grid](input,
  248. qweight,
  249. result,
  250. qzeros,
  251. scales,
  252. M,
  253. N,
  254. K,
  255. group_size,
  256. BLOCK_SIZE_M=block_size_m,
  257. BLOCK_SIZE_N=block_size_n,
  258. BLOCK_SIZE_K=block_size_k,
  259. SPLIT_K=split_k_iters)
  260. return result