
improve gptq_marlin_24 prefill performance

AlpinDale, 7 months ago
commit d8667fcb98
2 changed files with 65 additions and 27 deletions
  1. aphrodite/quantization/gptq_marlin_24.py (+25 −12)
  2. kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu (+40 −15)

aphrodite/quantization/gptq_marlin_24.py (+25 −12)

@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
 from loguru import logger
 
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.quantization.base_config import (QuantizationConfig)
+from aphrodite.quantization.base_config import (
+    QuantizationConfig)
 from aphrodite.modeling.utils import set_weight_attrs
 
 HAS_QUANTS = False
@@ -14,6 +15,15 @@ with suppress(ImportError):
     from aphrodite._quant_C import quant_ops as ops
     HAS_QUANTS = True
 
+GPTQ_MARLIN_24_TILE = 16
+GPTQ_MARLIN_24_MIN_THREAD_N = 128
+GPTQ_MARLIN_24_MIN_THREAD_K = 128
+GPTQ_MARLIN_24_MAX_PARALLEL = 64
+
+GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
+
 
 class GPTQMarlin24Config(QuantizationConfig):
     """Config class for Marlin24.
@@ -27,15 +37,17 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.weight_bits = weight_bits
         self.group_size = group_size
 
-        if self.weight_bits != 4 and self.weight_bits != 8:
-            raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
-                self.weight_bits))
-
-        if self.group_size != 128 and self.group_size != -1:
+        # Verify
+        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
             raise ValueError(
-                "Currently, only group size 128 and -1 (channelwise) "
-                "is supported for Marlin24, but got group_size of "
-                f"{self.group_size}")
+                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
+                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
+                "are supported.")
+        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Marlin_24 does not support group_size = {self.group_size}. "
+                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
 
         # 4 Bits packed into 32 bit datatype.
         self.pack_factor = 32 // self.weight_bits
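For reference, a minimal, self-contained sketch of the new validation path. The constant values and error conditions mirror the diff above; the standalone validate_marlin_24_config() helper is hypothetical and exists only for illustration.

GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]


def validate_marlin_24_config(weight_bits: int, group_size: int) -> None:
    # Mirrors the checks added to GPTQMarlin24Config.__init__ above.
    if weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
        raise ValueError(
            f"Marlin_24 does not support weight_bits = {weight_bits}.")
    if group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
        raise ValueError(
            f"Marlin_24 does not support group_size = {group_size}.")


validate_marlin_24_config(weight_bits=4, group_size=128)  # passes
try:
    validate_marlin_24_config(weight_bits=4, group_size=64)
except ValueError as err:
    print(err)  # Marlin_24 does not support group_size = 64.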
@@ -44,14 +56,14 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.tile_size = 16
 
         # Min out_features dim
-        self.min_n_threads = 128
+        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
 
         # Min in_features dim
-        self.min_k_threads = 128
+        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
 
         # Max parallel problems to solve at once (improves large
        # batch performance)
-        self.max_parallel = 16
+        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
 
         # Permutation length used by the marlin kernels.
         self.perm_len = 1024
@@ -113,6 +125,7 @@ class GPTQMarlin24Config(QuantizationConfig):
 
 class GPTQMarlin24LinearMethod(LinearMethodBase):
     """Linear method for Marlin24.
+
     Args:
         quant_config: The Marlin24 quantization config.
     """

kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu (+40 −15)

@@ -48,12 +48,12 @@ namespace marlin_24 {
 // than 1 warp per schedule allows some more latency hiding. At the same time,
 // we want relatively few warps to have many registers per warp and small tiles.
 static constexpr int THREADS = 256;
-static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
+static constexpr int STAGES = 4;
 
 static constexpr int min_thread_n = 128;
 
 static constexpr int tile_size = 16;
-static constexpr int max_par = 16;
+static constexpr int max_par = 64;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
     for (int pipe = 0; pipe < stages;) {
       fetch_to_shared((pipe + stages - 1) % stages, pipe,
                       slice_iters >= stages);
+      matmul(pipe);
       wait_for_stage();
 
       fetch_to_registers(pipe + 1, (pipe + 1) % stages);
-      matmul(pipe);
 
       pipe++;
       slice_iters--;
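The reorder above issues matmul(pipe) on data that is already resident before wait_for_stage() blocks on the asynchronous copy for the next stage, so compute overlaps the fetch latency during pipeline start-up. Below is a rough CPU-side analogy of that overlap pattern, not the kernel itself; fetch_stage() and compute() are stand-ins.

import time
from concurrent.futures import ThreadPoolExecutor


def fetch_stage(stage: int) -> str:
    time.sleep(0.05)  # stands in for the async global->shared copy of a stage
    return f"tile-{stage}"


def compute(tile: str) -> None:
    time.sleep(0.05)  # stands in for the matmul on already-resident data


stages = 4
with ThreadPoolExecutor(max_workers=1) as pool:
    pending = pool.submit(fetch_stage, 0)               # prefetch the first stage
    for pipe in range(stages):
        tile = pending.result()                         # wait only for the current tile
        if pipe + 1 < stages:
            pending = pool.submit(fetch_stage, pipe + 1)  # start the next fetch...
        compute(tile)                                   # ...and compute while it is in flight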
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
       // than better compute utilization
       thread_k = 128;
       thread_m = 128;
-    } else {
+    } else if (prob_n <= 256) {
       thread_k = 64;
       thread_m = 256;
+    } else {
+      thread_k = 32;
+      thread_m = 512;
     }
   }
 
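The launch heuristic now has a third tier: once prob_n grows past 256, the host picks thread_k = 32 and thread_m = 512 instead of always falling back to 64x256. A sketch of the resulting selection, assuming the elided first branch also tests prob_n against some small threshold (SMALL_PROB_N_THRESHOLD below is hypothetical; the 256 cutoff and the tile pairs come from the diff):

SMALL_PROB_N_THRESHOLD = 64  # hypothetical; the real cutoff sits above this hunk


def pick_thread_tile(prob_n):
    """Return (thread_k, thread_m) following the updated heuristic."""
    if prob_n <= SMALL_PROB_N_THRESHOLD:
        return 128, 128   # small problems: favor latency over utilization
    elif prob_n <= 256:
        return 64, 256
    else:
        return 32, 512    # new tier for large prefill batches


print(pick_thread_tile(32))    # (128, 128)
print(pick_thread_tile(256))   # (64, 256)
print(pick_thread_tile(4096))  # (32, 512)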
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   int4* C_ptr = (int4*)C;
   const int4* s_ptr = (const int4*)s;
 
+  constexpr int max_m_blocks = 4;
+
   int* locks = (int*)workspace;
-  for (int i = 0; i < tot_n_blocks; i += 4) {
+  for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
     int thread_n_blocks = tot_n_blocks - i;
     prob_n = tot_n - 16 * i;
     int par = 1;
-    if (thread_n_blocks > 4) {
+    if (thread_n_blocks > max_m_blocks) {
       // Note that parallel > 1 currently only works for inputs without any
       // padding
-      par = (16 * thread_n_blocks - pad) / 64;
+      par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
       if (par > max_par) par = max_par;
-      prob_n = 64 * par;
-      i += 4 * (par - 1);
-      thread_n_blocks = 4;
+      prob_n = (max_m_blocks * 16) * par;
+      i += max_m_blocks * (par - 1);
+      thread_n_blocks = max_m_blocks;
     }
 
     // For compilation speed, we only define the kernel configurations that have
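With max_m_blocks fixed at 4, each launch still handles chunks of 64 rows, but up to max_par of those chunks are fused into a single launch, and max_par was raised from 16 to 64 earlier in this diff. A Python rendering of the loop's index arithmetic (launches_needed() is a hypothetical helper, used only to count iterations) shows how the higher cap cuts the number of sequential launches for a large prefill:

def launches_needed(tot_n_blocks, max_par, max_m_blocks=4, pad=0):
    # Reproduces the index math of the host loop above.
    launches = 0
    i = 0
    while i < tot_n_blocks:
        thread_n_blocks = tot_n_blocks - i
        par = 1
        if thread_n_blocks > max_m_blocks:
            par = (16 * thread_n_blocks - pad) // (max_m_blocks * 16)
            par = min(par, max_par)
            i += max_m_blocks * (par - 1)
        launches += 1
        i += max_m_blocks
    return launches


# 8192 prefill rows -> 512 blocks of 16 rows each.
print(launches_needed(512, max_par=16))  # 8 sequential launches with the old cap
print(launches_needed(512, max_par=64))  # 2 sequential launches with the new cap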
@@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     if (false) {
     }  //         BMxBNxBK,   group
     // 4-bit
-    CALL_IF_2_4(4, 8, 1, 4, -1)   // e.g., 16x128x128
-    CALL_IF_2_4(4, 8, 1, 4, 4)    // e.g., 16x128x128, 64
+    CALL_IF_2_4(4, 8, 1, 4, -1)  // e.g., 16x128x128
+    CALL_IF_2_4(4, 8, 1, 4, 4)   // e.g., 16x128x128, 64
+
     CALL_IF_2_4(4, 16, 1, 2, -1)  // e.g., 16x256x64
     CALL_IF_2_4(4, 16, 1, 2, 4)   // e.g., 16x256x64,  64
     CALL_IF_2_4(4, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     CALL_IF_2_4(4, 16, 4, 2, -1)
     CALL_IF_2_4(4, 16, 4, 2, 4)
 
+    CALL_IF_2_4(4, 32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_2_4(4, 32, 1, 1, 4)   // e.g., 16x256x64,  64
+    CALL_IF_2_4(4, 32, 2, 1, -1)  // e.g.. 32x256x64
+    CALL_IF_2_4(4, 32, 2, 1, 4)
+    CALL_IF_2_4(4, 32, 3, 1, -1)
+    CALL_IF_2_4(4, 32, 3, 1, 4)
+    CALL_IF_2_4(4, 32, 4, 1, -1)
+    CALL_IF_2_4(4, 32, 4, 1, 4)
+
     // 8-bit
-    CALL_IF_2_4(8, 8, 1, 4, -1)   // e.g., 16x128x128
-    CALL_IF_2_4(8, 8, 1, 4, 4)    // e.g., 16x128x128, 64
+    CALL_IF_2_4(8, 8, 1, 4, -1)  // e.g., 16x128x128
+    CALL_IF_2_4(8, 8, 1, 4, 4)   // e.g., 16x128x128, 64
+
     CALL_IF_2_4(8, 16, 1, 2, -1)  // e.g., 16x256x64
     CALL_IF_2_4(8, 16, 1, 2, 4)   // e.g., 16x256x64,  64
     CALL_IF_2_4(8, 16, 2, 2, -1)  // e.g.. 32x256x64
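The new CALL_IF_2_4(_, 32, m, 1, _) instantiations here and in the 8-bit block below appear to supply the kernel configurations behind the new thread_m = 512 / thread_k = 32 launch shape, for 1 to 4 batch tiles. The decoding below is an interpretation inferred from the "BMxBNxBK" column label and the "e.g." comments, not a documented contract of the macro:

def decode_call_if_2_4(bn_arg, bm_arg, bk_arg):
    # Scale factors inferred from the examples above: BN and BK line up with the
    # host-side thread_m / thread_k choices, BM with the batch tile (assumption).
    bm = 16 * bm_arg
    bn = 16 * bn_arg
    bk = 32 * bk_arg
    return f"{bm}x{bn}x{bk}"


print(decode_call_if_2_4(8, 1, 4))   # 16x128x128, matches the comment above
print(decode_call_if_2_4(16, 2, 2))  # 32x256x64, matches the comment above
print(decode_call_if_2_4(32, 4, 1))  # 64x512x32, the new large-prefill tiles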
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     CALL_IF_2_4(8, 16, 3, 2, 4)
     CALL_IF_2_4(8, 16, 4, 2, -1)
     CALL_IF_2_4(8, 16, 4, 2, 4)
+
+    CALL_IF_2_4(8, 32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_2_4(8, 32, 1, 1, 4)   // e.g., 16x256x64,  64
+    CALL_IF_2_4(8, 32, 2, 1, -1)  // e.g.. 32x256x64
+    CALL_IF_2_4(8, 32, 2, 1, 4)
+    CALL_IF_2_4(8, 32, 3, 1, -1)
+    CALL_IF_2_4(8, 32, 3, 1, 4)
+    CALL_IF_2_4(8, 32, 4, 1, -1)
+    CALL_IF_2_4(8, 32, 4, 1, 4)
     else {
       throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                                ", " + str(prob_k) + ", " + str(prob_n) + "]" +
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   int thread_k = -1;
   int thread_m = -1;
   int sms = -1;
-  int max_par = 16;
+  int max_par = marlin_24::max_par;
 
   int groupsize = -1;
   if (b_scales.size(0) > 1) {