awq_triton.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import torch
  2. import triton
  3. import triton.language as tl
  4. AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
  5. @triton.jit
  6. def awq_dequantize_kernel(
  7. qweight_ptr, # quantized matrix
  8. scales_ptr, # scales, per group
  9. zeros_ptr, # zeros, per group
  10. group_size, # Should always be one of the supported group sizes
  11. result_ptr, # Output matrix
  12. num_cols, # input num cols in qweight
  13. num_rows, # input num rows in qweight
  14. BLOCK_SIZE_X: tl.constexpr,
  15. BLOCK_SIZE_Y: tl.constexpr):
  16. # Setup the pids.
  17. pid_x = tl.program_id(axis=0)
  18. pid_y = tl.program_id(axis=1)
  19. # Compute offsets and masks for qweight_ptr.
  20. offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
  21. offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
  22. offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
  23. masks_y = offsets_y < num_rows
  24. masks_x = offsets_x < num_cols
  25. masks = masks_y[:, None] & masks_x[None, :]
  26. # Compute offsets and masks for result output ptr.
  27. result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
  28. result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
  29. 0, BLOCK_SIZE_X * 8)
  30. result_offsets = (8 * num_cols * result_offsets_y[:, None] +
  31. result_offsets_x[None, :])
  32. result_masks_y = result_offsets_y < num_rows
  33. result_masks_x = result_offsets_x < num_cols * 8
  34. result_masks = result_masks_y[:, None] & result_masks_x[None, :]
  35. # Load the weights.
  36. iweights = tl.load(qweight_ptr + offsets, masks)
  37. # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
  38. # that will map given indices to the correct order.
  39. reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
  40. tl.arange(0, 4)[:, None]).reshape(8)
  41. # Use this to compute a set of shifts that can be used to unpack and
  42. # reorder the values in iweights and zeros.
  43. shifts = reverse_awq_order_tensor * 4
  44. shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
  45. shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
  46. # Unpack and reorder: shift out the correct 4-bit value and mask.
  47. iweights = (iweights >> shifts) & 0xF
  48. # Compute zero offsets and masks.
  49. zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
  50. tl.arange(0, BLOCK_SIZE_Y) // group_size)
  51. zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
  52. zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
  53. zero_masks_y = zero_offsets_y < num_rows // group_size
  54. zero_masks_x = zero_offsets_x < num_cols
  55. zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
  56. # Load the zeros.
  57. zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
  58. # Unpack and reorder: shift out the correct 4-bit value and mask.
  59. zeros = (zeros >> shifts) & 0xF
  60. # Compute scale offsets and masks.
  61. scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
  62. tl.arange(0, BLOCK_SIZE_Y) // group_size)
  63. scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
  64. tl.arange(0, BLOCK_SIZE_X * 8))
  65. scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
  66. scale_offsets_x[None, :])
  67. scale_masks_y = scale_offsets_y < num_rows // group_size
  68. scale_masks_x = scale_offsets_x < num_cols * 8
  69. scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
  70. # Load the scales.
  71. scales = tl.load(scales_ptr + scale_offsets, scale_masks)
  72. # Dequantize.
  73. iweights = (iweights - zeros) * scales
  74. iweights = iweights.to(result_ptr.type.element_ty)
  75. # Finally, store.
  76. tl.store(result_ptr + result_offsets, iweights, result_masks)
  77. @triton.jit
  78. def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
  79. group_size, BLOCK_SIZE_M: tl.constexpr,
  80. BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  81. SPLIT_K: tl.constexpr):
  82. pid = tl.program_id(axis=0)
  83. pid_z = tl.program_id(1)
  84. # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
  85. # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
  86. num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
  87. pid_m = pid // num_pid_n
  88. pid_n = pid % num_pid_n
  89. accumulator_dtype = c_ptr.type.element_ty
  90. # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
  91. # accumulator = tl.arange(0, BLOCK_SIZE_N)
  92. # accumulator = tl.broadcast_to(accumulator[None, :],
  93. # (BLOCK_SIZE_M, BLOCK_SIZE_N))
  94. # accumulator = accumulator & 0x0
  95. # accumulator = accumulator.to(accumulator_dtype)
  96. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
  97. dtype=accumulator_dtype)
  98. # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
  99. # that will map given indices to the correct order.
  100. reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
  101. tl.arange(0, 4)[:, None]).reshape(8)
  102. # Create the necessary shifts to use to unpack.
  103. shifts = reverse_awq_order_tensor * 4
  104. shifts = tl.broadcast_to(shifts[None, :],
  105. (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
  106. shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
  107. # Offsets and masks.
  108. offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  109. masks_am = offsets_am < M
  110. offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
  111. tl.arange(0, BLOCK_SIZE_N) // 8)
  112. masks_bn = offsets_bn < N // 8
  113. offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
  114. tl.arange(0, BLOCK_SIZE_N) // 8)
  115. masks_zn = offsets_zn < N // 8
  116. offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  117. masks_sn = offsets_sn < N
  118. offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
  119. offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
  120. offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
  121. a_ptrs = a_ptr + offsets_a
  122. b_ptrs = b_ptr + offsets_b
  123. # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
  124. # block_offset = BLOCK_SIZE_K * SPLIT_K
  125. # for k in range(0, (K + block_offset - 1) // (block_offset)):
  126. for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
  127. masks_k = offsets_k < K
  128. masks_a = masks_am[:, None] & masks_k[None, :]
  129. a = tl.load(a_ptrs, mask=masks_a)
  130. masks_b = masks_k[:, None] & masks_bn[None, :]
  131. b = tl.load(b_ptrs, mask=masks_b)
  132. # Dequantize b.
  133. offsets_szk = (
  134. (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
  135. tl.arange(0, BLOCK_SIZE_K) // group_size)
  136. offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
  137. masks_zk = offsets_szk < K // group_size
  138. masks_z = masks_zk[:, None] & masks_zn[None, :]
  139. zeros_ptrs = zeros_ptr + offsets_z
  140. zeros = tl.load(zeros_ptrs, mask=masks_z)
  141. offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
  142. masks_sk = offsets_szk < K // group_size
  143. masks_s = masks_sk[:, None] & masks_sn[None, :]
  144. scales_ptrs = scales_ptr + offsets_s
  145. scales = tl.load(scales_ptrs, mask=masks_s)
  146. b = (b >> shifts) & 0xF
  147. zeros = (zeros >> shifts) & 0xF
  148. b = (b - zeros) * scales
  149. b = b.to(c_ptr.type.element_ty)
  150. # Accumulate results.
  151. accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
  152. offsets_k += BLOCK_SIZE_K * SPLIT_K
  153. a_ptrs += BLOCK_SIZE_K * SPLIT_K
  154. b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
  155. c = accumulator.to(c_ptr.type.element_ty)
  156. offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  157. offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  158. c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
  159. c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
  160. if SPLIT_K == 1:
  161. tl.store(c_ptrs, c, mask=c_mask)
  162. else:
  163. tl.atomic_add(c_ptrs, c, mask=c_mask)
  164. # qweights - [K , M // 8], int32
  165. # scales - [K // G, M ], float16
  166. # zeros - [K // G, M // 8], int32
  167. def awq_dequantize_triton(qweight: torch.Tensor,
  168. scales: torch.Tensor,
  169. zeros: torch.Tensor,
  170. block_size_x: int = 32,
  171. block_size_y: int = 32) -> torch.Tensor:
  172. K = qweight.shape[0]
  173. M = scales.shape[1]
  174. group_size = qweight.shape[0] // scales.shape[0]
  175. assert K > 0 and M > 0
  176. assert scales.shape[0] == K // group_size and scales.shape[1] == M
  177. assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
  178. assert group_size <= K
  179. assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
  180. # Result tensor:
  181. # number of rows = same as input tensor
  182. # number of cols = 8 x input tensor num cols
  183. result = torch.empty(qweight.shape[0],
  184. qweight.shape[1] * 8,
  185. device=qweight.device,
  186. dtype=scales.dtype)
  187. Y = qweight.shape[0] # num rows
  188. X = qweight.shape[1] # num cols
  189. grid = lambda META: (
  190. triton.cdiv(X, META['BLOCK_SIZE_X']),
  191. triton.cdiv(Y, META['BLOCK_SIZE_Y']),
  192. )
  193. awq_dequantize_kernel[grid](qweight,
  194. scales,
  195. zeros,
  196. group_size,
  197. result,
  198. X,
  199. Y,
  200. BLOCK_SIZE_X=block_size_x,
  201. BLOCK_SIZE_Y=block_size_y)
  202. return result
  203. # input - [M, K]
  204. # qweight - [K, N // 8]
  205. # qzeros - [K // G, N // 8]
  206. # scales - [K // G, N]
  207. # split_k_iters - parallelism along K-dimension, int, power of 2.
  208. def awq_gemm_triton(input: torch.Tensor,
  209. qweight: torch.Tensor,
  210. scales: torch.Tensor,
  211. qzeros: torch.Tensor,
  212. split_k_iters: int,
  213. block_size_m: int = 32,
  214. block_size_n: int = 32,
  215. block_size_k: int = 32) -> torch.Tensor:
  216. M, K = input.shape
  217. N = qweight.shape[1] * 8
  218. group_size = qweight.shape[0] // qzeros.shape[0]
  219. assert N > 0 and K > 0 and M > 0
  220. assert qweight.shape[0] == K and qweight.shape[1] == N // 8
  221. assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
  222. assert scales.shape[0] == K // group_size and scales.shape[1] == N
  223. assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
  224. assert split_k_iters <= 32
  225. assert group_size <= K
  226. assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
  227. grid = lambda META: (
  228. triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
  229. N, META['BLOCK_SIZE_N']),
  230. split_k_iters,
  231. )
  232. result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
  233. # A = input, B = qweight, C = result
  234. # A = M x K, B = K x N, C = M x N
  235. awq_gemm_kernel[grid](input,
  236. qweight,
  237. result,
  238. qzeros,
  239. scales,
  240. M,
  241. N,
  242. K,
  243. group_size,
  244. BLOCK_SIZE_M=block_size_m,
  245. BLOCK_SIZE_N=block_size_n,
  246. BLOCK_SIZE_K=block_size_k,
  247. SPLIT_K=split_k_iters)
  248. return result