Browse Source

gptq_marlin: 8bit GPTQ support

AlpinDale 7 tháng trước cách đây
mục cha
commit
c154578c97

+ 31 - 48
aphrodite/quantization/gptq_marlin.py

@@ -3,7 +3,6 @@ from contextlib import suppress
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
-import numpy
 import torch
 from torch.nn.parameter import Parameter
 
@@ -21,41 +20,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
 GPTQ_MARLIN_MAX_PARALLEL = 16
 
-GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
+GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
 GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 GPTQ_MARLIN_SUPPORTED_SYM = [True]
 
 
-# Precompute permutations for Marlin weight and scale shuffling
-#
-# Marlin works on [16,64] tiles. The goal of the permutations
-# is to reorder the weight data so that it is compatible
-# with the tensor-core format that is described here:
-# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
-#
-# As a result of this reordering, the vector loads inside the
-# kernel will get the data as it is needed for tensor-core
-# (without the need to use ldmatrix instructions)
-def _get_perms():
-    perm = []
-    for i in range(32):
-        perm1 = []
-        col = i // 4
-        for block in [0, 1]:
-            for row in [
-                    2 * (i % 4),
-                    2 * (i % 4) + 1,
-                    2 * (i % 4 + 4),
-                    2 * (i % 4 + 4) + 1,
-            ]:
-                perm1.append(16 * row + col + 8 * block)
-        for j in range(4):
-            perm.extend([p + 256 * j for p in perm1])
-
-    perm = numpy.array(perm)
-    interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
-    perm = perm.reshape((-1, 8))[:, interleave].ravel()  # type: ignore
-    perm = torch.from_numpy(perm)
+# Permutations for Marlin scale shuffling
+def get_scale_perms(num_bits):
     scale_perm = []
     for i in range(8):
         scale_perm.extend([i + 8 * j for j in range(8)])
@@ -63,23 +34,21 @@ def _get_perms():
     for i in range(4):
         scale_perm_single.extend(
             [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-    return perm, scale_perm, scale_perm_single
-
-
-_perm, _scale_perm, _scale_perm_single = _get_perms()
+    return scale_perm, scale_perm_single
 
 
 def get_pack_factor(num_bits):
-    assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
-        f"Unsupported num_bits = {num_bits}")
+    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
+            ), f"Unsupported num_bits = {num_bits}"
     return 32 // num_bits
 
 
-def marlin_permute_scales(s, size_k, size_n, group_size):
+def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
+    scale_perm, scale_perm_single = get_scale_perms(num_bits)
     if group_size < size_k and group_size != -1:
-        s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
     else:
-        s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
     s = s.reshape((-1, size_n)).contiguous()
 
     return s
@@ -424,6 +393,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
                 layer.g_idx_sort_indices,
                 part_size_k,
                 part_size_n,
+                self.quant_config.weight_bits,
             )
             replace_tensor("qweight", marlin_qweight)
 
@@ -433,15 +403,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             if self.quant_config.desc_act:
                 scales_size_k = full_size_k
 
-            marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
-                                                  scales_size_n,
-                                                  self.quant_config.group_size)
+            marlin_scales = marlin_permute_scales(
+                layer.scales,
+                scales_size_k,
+                scales_size_n,
+                self.quant_config.group_size,
+                self.quant_config.weight_bits,
+            )
             replace_tensor("scales", marlin_scales)
 
-        output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
-                                      layer.g_idx, layer.g_idx_sort_indices,
-                                      layer.workspace, size_m, part_size_n,
-                                      part_size_k, layer.is_k_full)
+        output = ops.gptq_marlin_gemm(
+            reshaped_x,
+            layer.qweight,
+            layer.scales,
+            layer.g_idx,
+            layer.g_idx_sort_indices,
+            layer.workspace,
+            self.quant_config.weight_bits,
+            size_m,
+            part_size_n,
+            part_size_k,
+            layer.is_k_full,
+        )
 
         if bias is not None:
             output.add_(bias)  # In-place add

+ 377 - 175
kernels/quantization/gptq_marlin/gptq_marlin.cu

@@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
                                     int4 *__restrict__ out_int4_ptr, int size_m,
                                     int size_k, int block_rows) {}
 
-template <const int threads,         // number of threads in a threadblock
+template <const int num_bits,        // number of bits used for weights
+          const int threads,         // number of threads in a threadblock
           const int thread_m_blocks, // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the threadblock
           const int thread_n_blocks, // same for n dimension (output)
@@ -62,8 +63,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                torch::Tensor &b_scales, torch::Tensor &g_idx,
                                torch::Tensor &perm, torch::Tensor &workspace,
-                               int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full) {
+                               int64_t num_bits, int64_t size_m, int64_t size_n,
+                               int64_t size_k, bool is_k_full) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});
@@ -114,11 +115,21 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
   return res;
 }
 
+// Constructs destination register by taking bytes from 2 sources (based on mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
 // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
 // values. We mostly follow the strategy in the link below, with some small
 // changes:
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__device__ inline FragB dequant(int q) {
+__device__ inline FragB dequant_4bit(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) {
   return frag_b;
 }
 
+__device__ inline FragB dequant_8bit(int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
+                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
+                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
 __device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
@@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
   frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
 }
 
+// Given 2 floats multiply by 2 scales (halves)
+__device__ inline void scale_float(float *c, FragS &s) {
+  __half *s_ptr = reinterpret_cast<__half *>(&s);
+  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
+}
+
 // Wait until barrier reaches `count`, then lock for current threadblock.
 __device__ inline void barrier_acquire(int *lock, int count) {
   if (threadIdx.x == 0) {
@@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
   }
 }
 
-template <const int threads,         // number of threads in a threadblock
+template <const int num_bits,        // number of bits used for weights
+          const int threads,         // number of threads in a threadblock
           const int thread_m_blocks, // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the threadblock
           const int thread_n_blocks, // same for n dimension (output)
@@ -286,6 +323,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   // configurations, while requiring as few slow global cross-threadblock
   // reductions as possible.
 
+  constexpr int pack_factor = 32 / num_bits;
+
   // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
   // better partitioning with less reductions
   int parallel = 1;
@@ -385,21 +424,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
 
   // B sizes/strides
-  int b_gl_stride = 16 * prob_n / 32;
-  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
+  int b_gl_stride = 16 * prob_n / (pack_factor * 4);
+  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
+  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
+  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
+
   int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
-  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
-  constexpr int b_sh_wr_delta = threads;
-  constexpr int b_sh_rd_delta = threads;
+  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
+  constexpr int b_sh_wr_delta = threads * b_thread_vecs;
+  constexpr int b_sh_rd_delta = threads * b_thread_vecs;
   constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
   constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
 
   // Scale sizes/strides without act_order
   int s_gl_stride = prob_n / 8;
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
-  constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks
-                                  ? thread_k_blocks / group_blocks
-                                  : 1;
+  constexpr int s_tb_groups =
+      !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
+          ? thread_k_blocks / group_blocks
+          : 1;
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
 
@@ -425,12 +468,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
       a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
   a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
 
-  int b_gl_rd =
-      b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
+                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
+  int b_sh_wr = threadIdx.x * b_thread_vecs;
+  int b_sh_rd = threadIdx.x * b_thread_vecs;
 
   // For act_order
   constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -442,8 +485,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   // No act_order
   int s_gl_rd;
   if constexpr (!has_act_order) {
-    s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
-              s_sh_stride * slice_col + threadIdx.x;
+    if constexpr (group_blocks == -1) {
+      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+    } else {
+      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+                s_sh_stride * slice_col + threadIdx.x;
+    }
   }
   int s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
@@ -511,7 +558,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
-  I4 frag_b_quant[2];
+  I4 frag_b_quant[2][b_thread_vecs];
   FragC frag_c[thread_m_blocks][4][2];
   FragS frag_s[2][4];        // No act-order
   FragS act_frag_s[2][4][4]; // For act-order
@@ -575,7 +622,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
       int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
 #pragma unroll
       for (int i = 0; i < b_sh_wr_iters; i++) {
-        cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+#pragma unroll
+        for (int j = 0; j < b_thread_vecs; j++) {
+          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
+        }
+
         B_ptr[i] += b_gl_rd_delta_o;
       }
 
@@ -602,15 +653,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
             // Only fetch scales if this tile starts a new group
             if (pipe % (group_blocks / thread_k_blocks) == 0) {
               if (s_sh_wr_pred) {
-                cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
               }
               s_gl_rd += s_gl_rd_delta;
             }
           } else {
             for (int i = 0; i < s_tb_groups; i++) {
               if (s_sh_wr_pred) {
-                cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr],
-                                 &scales_ptr[s_gl_rd]);
+                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
+                          &scales_ptr[s_gl_rd]);
               }
               s_gl_rd += s_gl_rd_delta;
             }
@@ -641,14 +692,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     for (int i = 0; i < thread_m_blocks; i++)
       ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
     int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
-    frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
-        &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+
+#pragma unroll
+    for (int i = 0; i < b_thread_vecs; i++) {
+      frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>(
+          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
+    }
   };
 
   bool is_same_group[stages];
   int same_group_id[stages];
 
   auto init_same_group = [&](int pipe) {
+    if constexpr (!has_act_order) {
+      is_same_group[pipe] = false;
+      same_group_id[pipe] = 0;
+      return;
+    }
+
     int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
     int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage);
 
@@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
 // dequantization and matmul operations.
 #pragma unroll
     for (int j = 0; j < 4; j++) {
-      int b_quant = frag_b_quant[k % 2][j];
-      int b_quant_shift = b_quant >> 8;
+      FragB frag_b0;
+      FragB frag_b1;
+      if constexpr (num_bits == 4) {
+        int b_quant = frag_b_quant[k % 2][0][j];
+        int b_quant_shift = b_quant >> 8;
+
+        frag_b0 = dequant_4bit(b_quant);
+        frag_b1 = dequant_4bit(b_quant_shift);
 
-      FragB frag_b0 = dequant(b_quant);
+      } else {
+        int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
+        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+
+        frag_b0 = dequant_8bit(b_quant_0);
+        frag_b1 = dequant_8bit(b_quant_1);
+      }
 
       // Apply scale to frag_b0
       if constexpr (has_act_order) {
@@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
         }
       }
 
-      FragB frag_b1 = dequant(b_quant_shift);
-
       // Apply scale to frag_b1
       if constexpr (has_act_order) {
         scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   // multiple warps that accumulate their partial sums of the same output
   // location; which we have to reduce over in the end. We do in shared memory.
   auto thread_block_reduce = [&]() {
-    constexpr int red_off = threads / b_sh_stride / 2;
+    constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
-      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
-      constexpr int red_sh_delta = b_sh_stride;
-      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
-                      (threadIdx.x % b_sh_stride);
+      int red_idx = threadIdx.x / b_sh_stride_threads;
+      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride_threads;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
+                      (threadIdx.x % b_sh_stride_threads);
 
       // Parallel logarithmic shared memory reduction. We make sure to avoid any
       // unnecessary read or write iterations, e.g., for two warps we write only
@@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   };
 
   // Since multiple threadblocks may process parts of the same column slice, we
-  // finally have to globally reduce over the results. As the striped portioning
+  // finally have to globally reduce over the results. As the striped partitioning
   // minimizes the number of such reductions and our outputs are usually rather
   // small, we perform this reduction serially in L2 cache.
   auto global_reduce = [&](bool first = false, bool last = false) {
@@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     auto write = [&](int idx, float c0, float c1, FragS &s) {
       half2 res = __halves2half2(__float2half(c0), __float2half(c1));
 
-      // For per-column quantization we finally apply the scale here
-      if constexpr (!has_act_order && group_blocks == -1) {
+      // For per-column quantization we finally apply the scale here (only for
+      // 4-bit)
+      if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
         res = __hmul2(res, s[0]);
       }
 
       ((half2 *)sh)[idx] = res;
     };
+
     if (threadIdx.x / 32 < thread_n_blocks / 4) {
 #pragma unroll
       for (int i = 0; i < thread_m_blocks; i++) {
@@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     // ensure all shared memory accesses are static. Note that both pipelines
     // have even length meaning that the next iteration will always start at
     // index 0.
+
 #pragma unroll
     for (int pipe = 0; pipe < stages;) {
 #pragma unroll
@@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
       // For per-column scales, we only fetch them here in the final step before
       // write-out
       if constexpr (!has_act_order && group_blocks == -1) {
-        if (last) {
+        if constexpr (num_bits == 8) {
           if (s_sh_wr_pred) {
-            cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
           }
           cp_async_fence();
+        } else {
+          if (last) {
+            if (s_sh_wr_pred) {
+              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+            }
+            cp_async_fence();
+          }
         }
       }
 
       thread_block_reduce();
       if constexpr (!has_act_order && group_blocks == -1) {
-        if (last) {
+        if constexpr (num_bits == 8) {
           cp_async_wait<0>();
           __syncthreads();
           if (threadIdx.x / 32 < thread_n_blocks / 4) {
             reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
             reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
           }
+
+        } else {
+          if (last) {
+            cp_async_wait<0>();
+            __syncthreads();
+            if (threadIdx.x / 32 < thread_n_blocks / 4) {
+              reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+              reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+            }
+          }
+        }
+      }
+
+      // For 8-bit channelwise, we apply the scale before the global reduction
+      // that converts the fp32 results to fp16 (so that we avoid possible
+      // overflow in fp16)
+      if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
+        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+#pragma unroll
+          for (int i = 0; i < thread_m_blocks; i++) {
+#pragma unroll
+            for (int j = 0; j < 4; j++) {
+              scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
+                          frag_s[j / 2][2 * (j % 2) + 0]);
+              scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
+                          frag_s[j / 2][2 * (j % 2) + 0]);
+
+              scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
+                          frag_s[j / 2][2 * (j % 2) + 1]);
+              scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
+                          frag_s[j / 2][2 * (j % 2) + 1]);
+            }
+          }
         }
       }
 
@@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
           s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
         }
 
-        // if (blockIdx.x == 0 && threadIdx.x == 0) {
-        //   printf("Move\n");
-        // }
         start_pipes();
       }
     }
   }
 }
 
-#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,           \
+#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
                   HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)                    \
-  else if (thread_m_blocks == THREAD_M_BLOCKS &&                               \
+  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&       \
            thread_n_blocks == THREAD_N_BLOCKS &&                               \
            thread_k_blocks == THREAD_K_BLOCKS &&                               \
            has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
            num_threads == NUM_THREADS) {                                       \
     cudaFuncSetAttribute(                                                      \
-        Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
-               pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>,                      \
+        Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,        \
+               THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>,     \
         cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
-    Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
-           pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>                           \
+    Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,            \
+           THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>          \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
             A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
             prob_k, locks);                                                    \
@@ -1158,28 +1270,92 @@ typedef struct {
   int num_threads;
 } thread_config_t;
 
-thread_config_t small_batch_thread_configs[] = {
+typedef struct {
+  int max_m_blocks;
+  thread_config_t tb_cfg;
+} exec_config_t;
+
+thread_config_t thread_configs[] = {
     // Ordered by priority
 
     // thread_k, thread_n, num_threads
-    {128, 128, 256}, // Default
-    {128, 64, 128},  // Reduce N 2X, same K
-    {64, 256, 256},  // Reduce K 2X, increase N 2X
-    {64, 128, 128},  // Reduce K 2X, same N
+    {64, 256, 256}, // Default (max cache usage)
+    {64, 128, 128}, // Reduce N, reduce warps
+    {128, 64, 128}, // Reduce N more, but increase K
+
 };
 
-thread_config_t large_batch_thread_configs[] = {
-    // Ordered by priority
+int get_scales_cache_size(thread_config_t const &th_config, int prob_m,
+                          int prob_n, int prob_k, int num_bits, int group_size,
+                          bool has_act_order, bool is_k_full) {
+  bool cache_scales_chunk = has_act_order && !is_k_full;
 
-    // thread_k, thread_n, num_threads
-    {64, 256, 256}, // Default
-    {128, 64, 128}, // Reduce N 2X, same K
-    {64, 128, 128}, // Reduce N 2X, same K
-                    // {128, 64, 128},  // Reduce N 4X, increase K 2X
-};
+  int tb_n = th_config.thread_n;
+  int tb_k = th_config.thread_k;
+
+  // Get max scale groups per thread-block
+  int tb_groups;
+  if (group_size == -1) {
+    tb_groups = 1;
+  } else if (group_size == 0) {
+    tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
+  } else {
+    tb_groups = div_ceil(tb_k, group_size);
+  }
+
+  if (cache_scales_chunk) {
+    int load_groups =
+        tb_groups * pipe_stages * 2;    // Chunk size is 2x pipeline over dim K
+    load_groups = max(load_groups, 32); // We load at least 32 scale groups
+    return load_groups * tb_n * 2;
+
+  } else {
+    int tb_scales = tb_groups * tb_n * 2;
+
+    return tb_scales * pipe_stages;
+  }
+}
+
+bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks,
+                         int prob_m, int prob_n, int prob_k, int num_bits,
+                         int scales_cache_size, int max_shared_mem) {
+  int pack_factor = 32 / num_bits;
+
+  // Get B size
+  int tb_k = th_config.thread_k;
+  int tb_n = th_config.thread_n;
+
+  int b_size = (tb_k * tb_n / pack_factor) * 4;
+
+  // Get A size
+  int m_blocks = div_ceil(prob_m, 16);
+  int tb_max_m = 16;
 
-bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
-                     int prob_k) {
+  while (true) {
+    if (m_blocks >= max_m_blocks) {
+      tb_max_m *= max_m_blocks;
+      break;
+    }
+
+    max_m_blocks--;
+    if (max_m_blocks == 0) {
+      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
+    }
+  }
+
+  int a_size = (tb_max_m * tb_k) * 2;
+
+  float pipe_size = (a_size + b_size) * pipe_stages;
+
+  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
+
+  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+}
+
+bool is_valid_config(thread_config_t const &th_config, int max_m_blocks,
+                     int prob_m, int prob_n, int prob_k, int num_bits,
+                     int group_size, bool has_act_order, bool is_k_full,
+                     int max_shared_mem) {
   // Sanity
   if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
       th_config.num_threads == -1) {
@@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
     return false;
   }
 
+  //  Determine cache for scales
+  int scales_cache_size =
+      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
+                            group_size, has_act_order, is_k_full);
+
+  // Check that pipeline fits into cache
+  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                           num_bits, scales_cache_size, max_shared_mem)) {
+    return false;
+  }
+
   return true;
 }
 
-thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
-
-  // TODO: Enable if needed after some more testing
-  if (prob_m <= 0) {
-    for (auto th_config : small_batch_thread_configs) {
-      if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
-        return th_config;
+exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
+                                      int num_bits, int group_size,
+                                      bool has_act_order, bool is_k_full,
+                                      int max_shared_mem) {
+  int max_m_blocks = 4;
+  while (max_m_blocks > 0) {
+    for (auto th_config : thread_configs) {
+      if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+                          num_bits, group_size, has_act_order, is_k_full,
+                          max_shared_mem)) {
+        return exec_config_t{max_m_blocks, th_config};
       }
     }
 
-  } else {
-    for (auto th_config : large_batch_thread_configs) {
-      if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
-        return th_config;
-      }
-    }
+    printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
+           "GPU cache. This may "
+           "hurt performance. Consider upgrading your GPU.\n");
+
+    max_m_blocks--; // Process less M blocks per invocation to reduce cache
+                    // usage
   }
 
-  return thread_config_t{-1, -1, -1};
+  return exec_config_t{0, {-1, -1, -1}};
 }
 
-#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)                               \
-  __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)                       \
-  __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)                       \
-  __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)                       \
-  __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)                       \
+#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)                     \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)             \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)             \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)             \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)             \
                                                                                \
-  __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)                     \
-  __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)                      \
-  __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)                      \
-  __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)                      \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)           \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)            \
                                                                                \
-  __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)                     \
-  __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)                      \
-  __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)                      \
-  __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)                      \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)           \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)            \
                                                                                \
-  __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)                     \
-  __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)                      \
-  __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)                      \
-  __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)                      \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)           \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)            \
                                                                                \
-  __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)                     \
-  __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)                      \
-  __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)                      \
-  __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
-
-void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
-                 void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k,
-                 void *workspace, bool has_act_order, bool is_k_full,
-                 int num_groups, int group_size, int dev = 0,
-                 cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1,
-                 int sms = -1, int max_par = 16) {
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)           \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)            \
+  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+
+void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
+                     void *g_idx, void *perm, void *a_tmp, int prob_m,
+                     int prob_n, int prob_k, void *workspace, int num_bits,
+                     bool has_act_order, bool is_k_full, int num_groups,
+                     int group_size, int dev, cudaStream_t stream, int thread_k,
+                     int thread_n, int sms, int max_par) {
+  TORCH_CHECK(num_bits == 4 || num_bits == 8,
+              "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
               ", ", prob_n, ", ", prob_k, "]");
 
@@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
   TORCH_CHECK(max_shared_mem > 0);
 
   // Set thread config
-  thread_config_t th_config;
+  exec_config_t exec_cfg;
   if (thread_k != -1 && thread_n != -1) {
     // User-defined config
-    th_config = thread_config_t{thread_k, thread_n, default_threads};
+    exec_cfg =
+        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
   } else {
     // Auto config
-    th_config = determine_thread_config(prob_m, prob_n, prob_k);
+    exec_cfg =
+        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
+                                has_act_order, is_k_full, max_shared_mem);
   }
 
-  TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
-              "Invalid thread config: thread_k = " + str(th_config.thread_k) +
-                  ", thread_n = " + str(th_config.thread_n) +
-                  ", num_threads = " + str(th_config.num_threads) +
-                  " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
-                  str(prob_n) + "]");
-
-  int num_threads = th_config.num_threads;
-  thread_k = th_config.thread_k;
-  thread_n = th_config.thread_n;
+  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
+                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
+                                  prob_m, prob_n, prob_k, num_bits, group_size,
+                                  has_act_order, is_k_full, max_shared_mem),
+              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
+              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
+              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
+              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
+              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
+              ", group_size = ", group_size,
+              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
+              ", max_shared_mem = ", max_shared_mem);
+
+  int num_threads = exec_cfg.tb_cfg.num_threads;
+  thread_k = exec_cfg.tb_cfg.thread_k;
+  thread_n = exec_cfg.tb_cfg.thread_n;
 
   int thread_k_blocks = thread_k / 16;
   int thread_n_blocks = thread_n / 16;
@@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
   }
 
   // Main loop
-  for (int i = 0; i < tot_m_blocks; i += 4) {
+  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
     int thread_m_blocks = tot_m_blocks - i;
     prob_m = tot_m - 16 * i;
     int par = 1;
-    if (thread_m_blocks > 4) {
+    if (thread_m_blocks > exec_cfg.max_m_blocks) {
       // Note that parallel > 1 currently only works for inputs without any
       // padding
-      par = (16 * thread_m_blocks - pad) / 64;
+      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
       if (par > max_par)
         par = max_par;
-      prob_m = 64 * par;
-      i += 4 * (par - 1);
-      thread_m_blocks = 4;
+      prob_m = (16 * exec_cfg.max_m_blocks) * par;
+      i += exec_cfg.max_m_blocks * (par - 1);
+      thread_m_blocks = exec_cfg.max_m_blocks;
     }
 
     // Define kernel configurations
     if (false) {
     }
-    CALL_IF(16, 4, 256)
-    CALL_IF(8, 8, 256)
-    CALL_IF(8, 4, 128)
-    CALL_IF(4, 8, 128)
+    CALL_IF(4, 32, 2, 256)
+    CALL_IF(4, 16, 4, 256)
+    CALL_IF(4, 8, 4, 128)
+    CALL_IF(4, 4, 8, 128)
+    CALL_IF(8, 32, 2, 256)
+    CALL_IF(8, 16, 4, 256)
+    CALL_IF(8, 8, 4, 128)
+    CALL_IF(8, 4, 8, 128)
     else {
       TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                              str(prob_n) + ", " + str(prob_k) + "]" +
@@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
 torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                torch::Tensor &b_scales, torch::Tensor &g_idx,
                                torch::Tensor &perm, torch::Tensor &workspace,
-                               int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full) {
+                               int64_t num_bits, int64_t size_m, int64_t size_n,
+                               int64_t size_k, bool is_k_full) {
+  // Verify num_bits
+  TORCH_CHECK(num_bits == 4 || num_bits == 8,
+              "num_bits must be 4 or 8. Got = ", num_bits);
+  int pack_factor = 32 / num_bits;
+
   // Verify A
-  TORCH_CHECK(a.size(0) == size_m,
-              "Shape mismatch: a.size(0) = " + str(a.size(0)) +
-                  ", size_m = " + str(size_m));
-  TORCH_CHECK(a.size(1) == size_k,
-              "Shape mismatch: a.size(1) = " + str(a.size(1)) +
-                  ", size_k = " + str(size_k));
+  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
+              ", size_m = ", size_m);
+  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
+              ", size_k = ", size_k);
 
   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0,
-              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
-                  str(gptq_marlin::tile_size));
+  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", gptq_marlin::tile_size);
   TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
-              "Shape mismatch: b_q_weight.size(0) = " +
-                  str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
-                  ", tile_size = " + str(gptq_marlin::tile_size));
-  TORCH_CHECK(
-      b_q_weight.size(1) % gptq_marlin::tile_size == 0,
-      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
-          " is not divisible by tile_size = " + str(gptq_marlin::tile_size));
-  int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) *
-                      gptq_marlin::pack_factor_4bit;
-  TORCH_CHECK(size_n == actual_size_n,
-              "size_n = " + str(size_n) +
-                  ", actual_size_n = " + str(actual_size_n));
+              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
+              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              "b_q_weight.size(1) = ", b_q_weight.size(1),
+              " is not divisible by tile_size = ", gptq_marlin::tile_size);
+  int actual_size_n =
+      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
+              ", actual_size_n = ", actual_size_n);
 
   // Verify device and strides
   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
@@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
   // Verify g_idx and perm
   TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
                   (g_idx.size(0) == size_k && perm.size(0) == size_k),
-              "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) +
-                  " and perm.size(0) = " + str(perm.size(0)) +
-                  ", where size_k = " + str(size_k));
+              "Unexpected g_idx.size(0) = ", g_idx.size(0),
+              " and perm.size(0) = ", perm.size(0),
+              ", where size_k = ", size_k);
 
   // Detect groupsize and act_order
   int num_groups = -1;
@@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
   if (has_act_order) {
     if (is_k_full) {
       TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
-      TORCH_CHECK(size_k % num_groups == 0,
-                  "size_k = " + str(size_k) +
-                      ", is not divisible by num_groups = " + str(num_groups));
+      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
+                  ", is not divisible by num_groups = ", num_groups);
       group_size = size_k / num_groups;
     } else {
       group_size = 0;
@@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
 
   } else {
     if (num_groups > 1) {
-      TORCH_CHECK(size_k % num_groups == 0,
-                  "size_k = " + str(size_k) +
-                      ", is not divisible by b_scales.size(0) = " +
-                      str(b_scales.size(0)));
+      TORCH_CHECK(
+          size_k % num_groups == 0, "size_k = ", size_k,
+          ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
       group_size = size_k / num_groups;
     } else {
       group_size = -1;
@@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
   }
 
   // Verify workspace size
-  TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0,
-              "size_n = " + str(size_n) +
-                  ", is not divisible by min_thread_n = " +
-                  str(gptq_marlin::min_thread_n));
+  TORCH_CHECK(
+      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
+      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
   int min_workspace_size =
       (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
-              "workspace.numel = " + str(workspace.numel()) +
-                  " is below min_workspace_size = " + str(min_workspace_size));
+              "workspace.numel = ", workspace.numel(),
+              " is below min_workspace_size = ", min_workspace_size);
 
   int dev = a.get_device();
-  gptq_marlin::marlin_cuda(
+  gptq_marlin::marlin_mm_f16i4(
       a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
       g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
-      size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups,
-      group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
-      sms, gptq_marlin::max_par);
+      size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
+      num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+      thread_k, thread_n, sms, gptq_marlin::max_par);
 
   return c;
 }

+ 2 - 6
kernels/quantization/gptq_marlin/gptq_marlin.cuh

@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
 static constexpr int tile_size = 16;
 static constexpr int max_par   = 16;
 
-static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit
-
 template <typename T, int n>
 struct Vec {
   T             elems[n];
@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
                "r"(smem), "l"(glob_ptr), "n"(BYTES));
 }
 
-__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
   const int BYTES = 16;
   uint32_t  smem  = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("{\n"
-               "   .reg .b64 p;\n"
-               "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-               "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
+               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
                "}\n" ::"r"(smem),
                "l"(glob_ptr), "n"(BYTES));
 }

+ 90 - 62
kernels/quantization/gptq_marlin/gptq_marlin_repack.cu

@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
-template <int const num_threads, bool const has_perm>
+template <int const num_threads, int const num_bits, bool const has_perm>
 __global__ void
 marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
                      uint32_t const *__restrict__ perm_ptr,
@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
 } // namespace gptq_marlin
 
 torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
-                                 int64_t size_k, int64_t size_n) {
+                                 int64_t size_k, int64_t size_n,
+                                 int64_t num_bits) {
   TORCH_CHECK_NOT_IMPLEMENTED(
       false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});
@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
 
 #else
 
-template <int const num_threads, bool const has_perm>
+template <int const num_threads, int const num_bits, bool const has_perm>
 __global__ void
 marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
                      uint32_t const *__restrict__ perm_ptr,
                      uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
+  constexpr int pack_factor = 32 / num_bits;
+
   int k_tiles = size_k / tile_k_size;
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);
@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
     sh_pipe_ptr += perm_size;
   }
 
+  constexpr int tile_ints = tile_k_size / pack_factor;
+
   constexpr int stage_n_threads = tile_n_size / 4;
-  constexpr int stage_k_threads =
-      has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
+  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
   constexpr int stage_size = stage_k_threads * stage_n_threads;
 
   auto load_perm_to_shared = [&](int k_tile_id) {
@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
             reinterpret_cast<uint32_t const *>(sh_perm_ptr);
 
         int src_k = sh_perm_int_ptr[k_id];
-        int src_k_packed = src_k / pack_factor_4bit;
+        int src_k_packed = src_k / pack_factor;
 
-        cp_async4_stream(
+        cp_async4(
             &sh_ptr[k_id * stage_n_threads + n_id],
             reinterpret_cast<int4 const *>(&(
                 b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
@@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
         int n_id = threadIdx.x % stage_n_threads;
 
         int first_k = k_tile_id * tile_k_size;
-        int first_k_packed = first_k / pack_factor_4bit;
+        int first_k_packed = first_k / pack_factor;
 
-        cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
-                         reinterpret_cast<int4 const *>(
-                             &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
-                                              first_n + (n_id * 4)])));
+        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
+                  reinterpret_cast<int4 const *>(
+                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
+                                       first_n + (n_id * 4)])));
       }
     }
 
@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
     int cur_n = warp_id * 16 + tc_col;
 
     constexpr int sh_stride = 64;
+    constexpr uint32_t mask = (1 << num_bits) - 1;
 
     int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
     uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
 
     uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
 
-    uint32_t vals[pack_factor_4bit];
+    uint32_t vals[8];
 
     if constexpr (has_perm) {
       for (int i = 0; i < 4; i++) {
         int k_idx = tc_row + tc_offsets[i];
 
         uint32_t src_k = sh_perm_int_ptr[k_idx];
-        uint32_t src_k_pos = src_k % pack_factor_4bit;
+        uint32_t src_k_pos = src_k % pack_factor;
 
         uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
-        uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
+        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
 
         uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
-        uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
+        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
 
         vals[i] = b1_cur_val;
         vals[4 + i] = b2_cur_val;
@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
 
     } else {
 
-      uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
-      uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];
-
-      uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
-      uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
+      uint32_t b1_vals[tile_ints];
+      uint32_t b2_vals[tile_ints];
 
 #pragma unroll
-      for (int i = 0; i < 2; i++) {
-        int cur_elem = tc_row + tc_offsets[i];
-        vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
-        vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
+      for (int i = 0; i < tile_ints; i++) {
+        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
+        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
       }
 
 #pragma unroll
-      for (int i = 2; i < 4; i++) {
-        int cur_elem = tc_row + tc_offsets[i] - 8;
-        vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
-        vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
+      for (int i = 0; i < 4; i++) {
+        int cur_elem = tc_row + tc_offsets[i];
+        int cur_int = cur_elem / pack_factor;
+        int cur_pos = cur_elem % pack_factor;
+
+        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
+        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
       }
     }
 
+    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
+    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+
     // Result of:
     // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-    constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
+    if constexpr (num_bits == 4) {
+      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
 
-    uint32_t res = 0;
+      uint32_t res = 0;
 #pragma unroll
-    for (int i = 0; i < pack_factor_4bit; i++) {
-      res |= vals[pack_idx[i]] << (i * 4);
-    }
+      for (int i = 0; i < 8; i++) {
+        res |= vals[pack_idx[i]] << (i * 4);
+      }
 
-    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
-    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+      out_ptr[out_offset + th_id * 4 + warp_id] = res;
 
-    out_ptr[out_offset + th_id * 4 + warp_id] = res;
+    } else {
+      constexpr int pack_idx[4] = {0, 2, 1, 3};
+
+      uint32_t res1 = 0;
+      uint32_t res2 = 0;
+#pragma unroll
+      for (int i = 0; i < 4; i++) {
+        res1 |= vals[pack_idx[i]] << (i * 8);
+        res2 |= vals[4 + pack_idx[i]] << (i * 8);
+      }
+
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
+      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
+    }
   };
 
   auto start_pipes = [&](int k_tile_id, int n_tile_id) {
@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
 
 } // namespace gptq_marlin
 
+#define CALL_IF(NUM_BITS, HAS_PERM)                                            \
+  else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                     \
+    cudaFuncSetAttribute(                                                      \
+        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads,         \
+                                          NUM_BITS, HAS_PERM>,                 \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
+    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS,   \
+                                      HAS_PERM>                                \
+        <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(     \
+            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);                \
+  }
+
 torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
-                                 int64_t size_k, int64_t size_n) {
+                                 int64_t size_k, int64_t size_n,
+                                 int64_t num_bits) {
   // Verify compatibility with marlin tile of 16x64
   TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
               " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
   TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
               " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
 
+  TORCH_CHECK(num_bits == 4 || num_bits == 8,
+              "num_bits must be 4 or 8. Got = ", num_bits);
+  int const pack_factor = 32 / num_bits;
+
   // Verify B
-  TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
+  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k,
-              ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
+              ", size_k = ", size_k, ", pack_factor = ", pack_factor);
   TORCH_CHECK(b_q_weight.size(1) == size_n,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
               " is not size_n = ", size_n);
@@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
   auto options = torch::TensorOptions()
                      .dtype(b_q_weight.dtype())
                      .device(b_q_weight.device());
-  torch::Tensor out = torch::empty(
-      {size_k / gptq_marlin::tile_size,
-       size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
-      options);
+  torch::Tensor out =
+      torch::empty({size_k / gptq_marlin::tile_size,
+                    size_n * gptq_marlin::tile_size / pack_factor},
+                   options);
 
   // Detect if there is act_order
   bool has_perm = perm.size(0) != 0;
@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                          cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
   TORCH_CHECK(max_shared_mem > 0);
 
-  if (has_perm) {
-    cudaFuncSetAttribute(
-        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
-        cudaFuncAttributeMaxDynamicSharedMemorySize,
-        max_shared_mem);
-    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
-        <<<blocks, gptq_marlin::repack_threads, max_shared_mem,
-           stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
-
-  } else {
-    cudaFuncSetAttribute(
-        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
-        cudaFuncAttributeMaxDynamicSharedMemorySize,
-        max_shared_mem);
-    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
-        <<<blocks, gptq_marlin::repack_threads, max_shared_mem,
-           stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
+  if (false) {
+  }
+  CALL_IF(4, false)
+  CALL_IF(4, true)
+  CALL_IF(8, false)
+  CALL_IF(8, true)
+  else {
+    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
+                ", has_perm = ", has_perm);
   }
 
   return out;

+ 3 - 1
kernels/quantization/quant_ops.h

@@ -149,6 +149,7 @@ torch::Tensor gptq_marlin_gemm(
   torch::Tensor &g_idx,
   torch::Tensor &perm,
   torch::Tensor &workspace,
+  int64_t num_bits,
   int64_t size_m,
   int64_t size_n,
   int64_t size_k,
@@ -158,7 +159,8 @@ torch::Tensor gptq_marlin_repack(
   torch::Tensor &b_q_weight,
   torch::Tensor &perm,
   int64_t size_k,
-  int64_t size_n);
+  int64_t size_n,
+  int64_t num_bits);
 
 // QuIP#
 at::Tensor e8p_mm_origorder(