gptq_marlin_repack.cu

#include "marlin.cuh"
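
// When device code is compiled for architectures older than SM 8.0, only a
// no-op kernel and a host stub that reports the op as unimplemented are
// emitted; the real implementation below relies on the cp.async-based
// helpers from marlin.cuh, which need Ampere or newer.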
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}
#else

namespace marlin {
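
// Repacks GPTQ-packed quantized weights (row-major, `pack_factor` values per
// int32 along k) into the tiled layout expected by the Marlin GEMM kernels.
// When `has_perm` is set, k rows are gathered through `perm_ptr` so that
// act_order checkpoints are handled during the repack. Each thread block
// processes a contiguous range of 16-row k-tiles and sweeps across all
// 64-column n-tiles.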
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
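
  // Tiles are streamed through shared memory with a multi-stage cp.async
  // pipeline (`repack_stages` buffers): while one stage is being repacked,
  // later stages are already being fetched.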
  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the
    // previous shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };
  extern __shared__ int4 sh[];

  constexpr int perm_size = tile_k_size / 4;

  int4* sh_perm_ptr = sh;
  int4* sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }

  constexpr int tile_ints = tile_k_size / pack_factor;

  constexpr int stage_n_threads = tile_n_size / 4;
  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
  constexpr int stage_size = stage_k_threads * stage_n_threads;
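
  // Copy this k-tile's slice of the permutation (tile_k_size int32 indices,
  // read as int4 vectors) into shared memory so the fetch and repack lambdas
  // can gather through it.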
  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;

    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };
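
  // Asynchronously copy one (tile_k_size x tile_n_size) tile of packed
  // weights into shared-memory stage `pipe`. With a permutation, each of the
  // tile_k_size rows is gathered from its source packed row; without one,
  // the tile_ints packed rows are copied directly. A fence is always issued
  // so the pipeline keeps a consistent number of outstanding cp.async groups.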
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;

    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        uint32_t const* sh_perm_int_ptr =
            reinterpret_cast<uint32_t const*>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor;

        cp_async4(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const*>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

    } else {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor;

        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                  reinterpret_cast<int4 const*>(
                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                       first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };
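
  // Repack one tile from shared-memory stage `pipe` into the output layout.
  // Four warps each cover 16 columns of the 64-column tile; within a warp,
  // the lane's (tc_row, tc_col) position mirrors the tensor-core fragment
  // layout, and each lane extracts 8 quantized values: the 4 k positions
  // tc_row + {0, 1, 8, 9} from column cur_n and the same 4 from column
  // cur_n + 8.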
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;

    constexpr int sh_stride = 64;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);

    uint32_t vals[8];
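
    // Gather the lane's 8 quantized values. With a permutation, the packed
    // source int is re-read per k index and the value is extracted at the
    // permuted position; without one, the tile_ints packed ints for this
    // column are read once and sliced locally.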
    if constexpr (has_perm) {
      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;

        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;

        vals[i] = b1_cur_val;
        vals[4 + i] = b2_cur_val;
      }

    } else {
      uint32_t b1_vals[tile_ints];
      uint32_t b2_vals[tile_ints];

#pragma unroll
      for (int i = 0; i < tile_ints; i++) {
        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
      }

#pragma unroll
      for (int i = 0; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        int cur_int = cur_elem / pack_factor;
        int cur_pos = cur_elem % pack_factor;

        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
      }
    }
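
    // Each tile packs into tile_k_size * tile_n_size / pack_factor output
    // ints, stored tile after tile (k-tile major) starting at out_offset.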
    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };
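
  // Prefill the first repack_stages - 1 shared-memory stages before entering
  // the steady-state loop.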
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
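
  // Main loop: for every k-tile assigned to this block, sweep across the
  // n-tiles in groups of repack_stages, overlapping the fetch of a future
  // tile with the repack of the current one.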
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin
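
// Extends an else-if dispatch chain: for a matching (num_bits, has_perm)
// combination, raise the kernel's dynamic shared memory limit and launch it
// with one thread block per SM.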
#define CALL_IF(NUM_BITS, HAS_PERM)                                          \
  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \
    cudaFuncSetAttribute(                                                    \
        marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,  \
                                          HAS_PERM>,                         \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS,      \
                                      HAS_PERM>                              \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \
  }
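
// Host entry point: validates shapes, dtypes and contiguity, allocates the
// repacked output of shape {size_k / tile_size, size_n * tile_size /
// pack_factor}, and dispatches the kernel. As a worked example (assuming
// marlin::tile_size == 16, consistent with the 16x64 marlin tile checked
// below): for num_bits = 4, size_k = 4096, size_n = 11008, b_q_weight is
// int32 of shape [512, 11008] and the output is int32 of shape [256, 22016].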
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", pack_factor = ", pack_factor);
  TORCH_CHECK(b_q_weight.size(1) == size_n,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);
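
  // The empty `if (false)` starts the else-if chain that CALL_IF extends;
  // the final else catches unsupported configurations.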
  if (false) {
  }
  CALL_IF(4, false)
  CALL_IF(4, true)
  CALL_IF(8, false)
  CALL_IF(8, true)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
                ", has_perm = ", has_perm);
  }

  return out;
}

#endif
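
// Meta ("fake tensor") implementation used for shape and dtype inference,
// e.g. under torch.compile; it only computes the symbolic output shape and
// never touches real data.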
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}