awq_marlin_repack.cu

#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                int64_t size_k, int64_t size_n,
                                int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else
namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;
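
  // Asynchronously copies one (tile_k_size x tile_n_size) block of packed
  // weights from global memory into shared-memory stage `pipe`; each
  // participating thread moves one int4, i.e. four packed uint32 values.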
  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };
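
  // Re-orders one resident shared-memory tile from AWQ's column-interleaved
  // packing into the Marlin weight layout, with each of the four warps
  // handling a 16-column slice arranged for the tensor-core fragment order
  // used by the Marlin GEMM kernel.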
  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
  #pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
  #pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
  #pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };
  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
  #pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
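
  // Software-pipelined main loop: for each k-tile, prime the pipeline with
  // start_pipes(), then repeatedly issue the async fetch for a future n-tile,
  // repack the stage already resident in shared memory, and wait for the next
  // stage to land.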
  #pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
  #pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }

      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin
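
// Dispatch helper: raises the kernel's dynamic shared-memory limit and
// launches the instantiation of awq_marlin_repack_kernel that matches the
// requested bit width; used in the if/else chain inside awq_marlin_repack.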
#define CALL_IF(NUM_BITS)                                                     \
  else if (num_bits == NUM_BITS) {                                           \
    cudaFuncSetAttribute(                                                     \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,  \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>       \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, out_ptr, size_k, size_n);                        \
  }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
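
  // Marlin stores B as 16x64 tiles (see the checks above); the output holds
  // the same number of packed int32 values as the input, reshaped to
  // {size_k / tile_size, size_n * tile_size / pack_factor}.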
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
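
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file; illustrative assumption only).
// It shows how the host entry point above might be exercised: the 4096x4096
// shape, the 4-bit width, and the helper name awq_marlin_repack_example are
// made-up example values. It is guarded with the same condition as the file
// so it only compiles where the four-argument awq_marlin_repack overload is
// defined.
// ---------------------------------------------------------------------------
#if !(defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800)
[[maybe_unused]] static torch::Tensor awq_marlin_repack_example() {
  int64_t size_k = 4096, size_n = 4096, num_bits = 4;
  int64_t pack_factor = 32 / num_bits;
  // AWQ packs `pack_factor` quantized values into each int32 along n, so the
  // packed weight is a {size_k, size_n / pack_factor} kInt tensor on the GPU.
  torch::Tensor b_q_weight =
      torch::randint(1 << 30, {size_k, size_n / pack_factor},
                     torch::dtype(at::kInt).device(torch::kCUDA));
  // The repacked result has shape
  // {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}.
  return awq_marlin_repack(b_q_weight, size_k, size_n, num_bits);
}
#endif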