Procházet zdrojové kódy

gptq_marlin: enable bfloat16

AlpinDale před 7 měsíci
rodič
revize
ad1c6b86a1

+ 4 - 4
aphrodite/quantization/gptq_marlin.py

@@ -103,7 +103,7 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.half]
+        return [torch.half, torch.bfloat16]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -191,9 +191,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             group_size = input_size
 
         # Validate dtype
-        if params_dtype != torch.float16:
-            raise ValueError(
-                f"The params dtype must be float16, but got {params_dtype}")
+        if params_dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(f"The params dtype must be float16 "
+                             f"or bfloat16, but got {params_dtype}")
 
         # Validate output_size_per_partition
         output_size_per_partition = sum(output_partition_sizes)

+ 186 - 79
kernels/quantization/gptq_marlin/gptq_marlin.cu

@@ -20,6 +20,11 @@
  */
 
 #include "gptq_marlin.cuh"
+#include "gptq_marlin_dtypes.cuh"
+
+#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\
+  std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
+  "only float16 and bfloat16 is supported");
 
 template <typename T> inline std::string str(T x) { return std::to_string(x); }
 
@@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
                                     int4 *__restrict__ out_int4_ptr, int size_m,
                                     int size_k, int block_rows) {}
 
-template <const int num_bits,        // number of bits used for weights
+template <typename scalar_t,         // compute dtype, half or nv_float16
+          const int num_bits,        // number of bits used for weights
           const int threads,         // number of threads in a threadblock
           const int thread_m_blocks, // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the threadblock
@@ -72,31 +78,37 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
 
 #else
 
-// Matrix fragments for tensor core instructions; their precise layout is
-// documented here:
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
-using FragA = Vec<half2, 4>;
-using FragB = Vec<half2, 2>;
-using FragC = Vec<float, 4>;
-using FragS = Vec<half2, 1>; // quantization scales
 
 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
-__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
-                           FragC &frag_c) {
+template <typename scalar_t>
+__device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
+                           const typename ScalarType<scalar_t>::FragB &frag_b,
+                           typename ScalarType<scalar_t>::FragC &frag_c) {
   const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
   const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
   float *c = reinterpret_cast<float *>(&frag_c);
-  asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-               "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
-               : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-               : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
-                 "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  if constexpr (std::is_same<scalar_t, half>::value) {
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+                "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+                : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+                  "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+                "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+                : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+                  "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+  } else {
+    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+  }
 }
 
 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
 // memory, directly in tensor core layout.
-__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
+template <typename scalar_t>
+__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const void *smem_ptr) {
   uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@@ -128,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) {
 // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
 // values. We mostly follow the strategy in the link below, with some small
 // changes:
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__device__ inline FragB dequant_4bit(int q) {
+// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -141,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) {
   const int SUB = 0x64086408;
   const int MUL = 0x2c002c00;
   const int ADD = 0xd480d480;
-  FragB frag_b;
+  typename ScalarType<half>::FragB frag_b;
   frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                       *reinterpret_cast<const half2 *>(&SUB));
   frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
@@ -150,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) {
   return frag_b;
 }
 
-__device__ inline FragB dequant_8bit(int q) {
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat16>(int q) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+  q >>= 4;
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC308C308;
+
+  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&lo),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+                      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return frag_b;
+}
+
+// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16
+// Reference:
+// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
+// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
+template <typename scalar_t>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
+  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
+}
+
+template <>
+__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
   static constexpr uint32_t mask_for_elt_01 = 0x5250;
   static constexpr uint32_t mask_for_elt_23 = 0x5351;
   static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -160,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) {
 
   static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
 
-  FragB frag_b;
+  typename ScalarType<half>::FragB frag_b;
   frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                       *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
   frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
@@ -168,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) {
   return frag_b;
 }
 
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat16>(int q) {
+  typename ScalarType<nv_bfloat16>::FragB frag_b;
+
+  float fp32_intermediates[4];
+  uint32_t * fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0]  = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2]  = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  fp32_intermediates[0] -= 8388736.f;
+  fp32_intermediates[1] -= 8388736.f;
+  fp32_intermediates[2] -= 8388736.f;
+  fp32_intermediates[3] -= 8388736.f;
+
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
+
+  return frag_b;
+}
+
 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
-__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
-  half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+template <typename scalar_t>
+__device__ inline void scale(typename ScalarType<scalar_t>::FragB &frag_b,
+                             typename ScalarType<scalar_t>::FragS &frag_s, int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]);
   frag_b[0] = __hmul2(frag_b[0], s);
   frag_b[1] = __hmul2(frag_b[1], s);
 }
 
 // Same as above, but for act_order (each K is multiplied individually)
-__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
-                              FragS &frag_s_3, FragS &frag_s_4, int i) {
-  __half2 s_val_1_2;
-  s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i];
-  s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i];
-
-  __half2 s_val_3_4;
-  s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i];
-  s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i];
+template <typename scalar_t>
+__device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
+                              typename ScalarType<scalar_t>::FragS &frag_s_1,
+                              typename ScalarType<scalar_t>::FragS &frag_s_2,
+                              typename ScalarType<scalar_t>::FragS &frag_s_3,
+                              typename ScalarType<scalar_t>::FragS &frag_s_4,
+                              int i) {
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; 
+  scalar_t2 s_val_1_2;
+  s_val_1_2.x = reinterpret_cast<scalar_t *>(&frag_s_1)[i];
+  s_val_1_2.y = reinterpret_cast<scalar_t *>(&frag_s_2)[i];
+
+  scalar_t2 s_val_3_4;
+  s_val_3_4.x = reinterpret_cast<scalar_t *>(&frag_s_3)[i];
+  s_val_3_4.y = reinterpret_cast<scalar_t *>(&frag_s_4)[i];
 
   frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
   frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
 }
 
 // Given 2 floats multiply by 2 scales (halves)
-__device__ inline void scale_float(float *c, FragS &s) {
-  __half *s_ptr = reinterpret_cast<__half *>(&s);
-  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
-  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
+template <typename scalar_t>
+__device__ inline void scale_float(float *c, typename ScalarType<scalar_t>::FragS &s) {
+  scalar_t *s_ptr = reinterpret_cast<scalar_t *>(&s);
+  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
 }
 
 // Wait until barrier reaches `count`, then lock for current threadblock.
@@ -286,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
   }
 }
 
-template <const int num_bits,        // number of bits used for weights
+template <typename scalar_t,         // compute dtype, half or nv_float16
+          const int num_bits,        // number of bits used for weights
           const int threads,         // number of threads in a threadblock
           const int thread_m_blocks, // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the threadblock
@@ -322,6 +411,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   // ensures good utilization of all SMs for many kinds of shape and GPU
   // configurations, while requiring as few slow global cross-threadblock
   // reductions as possible.
+  using Dtype = ScalarType<scalar_t>;
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+  using FragA = typename ScalarType<scalar_t>::FragA;
+  using FragB = typename ScalarType<scalar_t>::FragB;
+  using FragC = typename ScalarType<scalar_t>::FragC;
+  using FragS = typename ScalarType<scalar_t>::FragS;
 
   constexpr int pack_factor = 32 / num_bits;
 
@@ -690,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
 #pragma unroll
     for (int i = 0; i < thread_m_blocks; i++)
-      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+      ldsm4<scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
     int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
 
 #pragma unroll
@@ -834,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
         int b_quant = frag_b_quant[k % 2][0][j];
         int b_quant_shift = b_quant >> 8;
 
-        frag_b0 = dequant_4bit(b_quant);
-        frag_b1 = dequant_4bit(b_quant_shift);
+        frag_b0 = dequant_4bit<scalar_t>(b_quant);
+        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
 
       } else {
         int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
         int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
         int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
 
-        frag_b0 = dequant_8bit(b_quant_0);
-        frag_b1 = dequant_8bit(b_quant_1);
+        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
+        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
       }
 
       // Apply scale to frag_b0
       if constexpr (has_act_order) {
-        scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
                act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
       } else {
         if constexpr (group_blocks != -1) {
-          scale(frag_b0, frag_s[k % 2][j], 0);
+          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
         }
       }
 
       // Apply scale to frag_b1
       if constexpr (has_act_order) {
-        scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
                act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
 
       } else {
         if constexpr (group_blocks != -1) {
-          scale(frag_b1, frag_s[k % 2][j], 1);
+          scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
         }
       }
 
 #pragma unroll
       for (int i = 0; i < thread_m_blocks; i++) {
-        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
-        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
       }
     }
   };
@@ -978,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
             for (int j = 0; j < 2 * 4; j++) {
               reinterpret_cast<float *>(
                   &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
-                  __half2float(reinterpret_cast<__half *>(&c_red)[j]);
+                  Dtype::num2float(reinterpret_cast<scalar_t *>(&c_red)[j]);
             }
           }
           if (!last) {
             int4 c;
 #pragma unroll
             for (int j = 0; j < 2 * 4; j++) {
-              reinterpret_cast<__half *>(&c)[j] =
-                  __float2half(reinterpret_cast<float *>(
+              reinterpret_cast<scalar_t *>(&c)[j] =
+                  Dtype::float2num(reinterpret_cast<float *>(
                       &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
             }
             C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
@@ -1021,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
     // We first reorder in shared memory to guarantee the most efficient final
     // global write patterns
     auto write = [&](int idx, float c0, float c1, FragS &s) {
-      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+      scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
 
       // For per-column quantization we finally apply the scale here (only for
       // 4-bit)
@@ -1029,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
         res = __hmul2(res, s[0]);
       }
 
-      ((half2 *)sh)[idx] = res;
+      ((scalar_t2 *)sh)[idx] = res;
     };
 
     if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1191,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
           for (int i = 0; i < thread_m_blocks; i++) {
 #pragma unroll
             for (int j = 0; j < 4; j++) {
-              scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
+              scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
                           frag_s[j / 2][2 * (j % 2) + 0]);
-              scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
+              scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
                           frag_s[j / 2][2 * (j % 2) + 0]);
 
-              scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
+              scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
                           frag_s[j / 2][2 * (j % 2) + 1]);
-              scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
+              scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
                           frag_s[j / 2][2 * (j % 2) + 1]);
             }
           }
@@ -1246,22 +1341,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
   }
 }
 
-#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
-                  HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)                    \
-  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&       \
-           thread_n_blocks == THREAD_N_BLOCKS &&                               \
-           thread_k_blocks == THREAD_K_BLOCKS &&                               \
-           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
-           num_threads == NUM_THREADS) {                                       \
-    cudaFuncSetAttribute(                                                      \
-        Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,        \
-               THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>,     \
-        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
-    Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,            \
-           THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>          \
-        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
-            A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
-            prob_k, locks);                                                    \
+#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,    \
+                  HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)                       \
+  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&          \
+           thread_n_blocks == THREAD_N_BLOCKS &&                                  \
+           thread_k_blocks == THREAD_K_BLOCKS &&                                  \
+           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&      \
+           num_threads == NUM_THREADS) {                                          \
+    cudaFuncSetAttribute(                                                         \
+        Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
+               THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>,        \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);             \
+    Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,     \
+           THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>             \
+        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                        \
+            A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n,    \
+            prob_k, locks);                                                       \
   }
 
 typedef struct {
@@ -1461,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)            \
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
 
+template <typename scalar_t>
 void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
                      void *g_idx, void *perm, void *a_tmp, int prob_m,
                      int prob_n, int prob_k, void *workspace, int num_bits,
@@ -1730,12 +1826,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
               " is below min_workspace_size = ", min_workspace_size);
 
   int dev = a.get_device();
-  gptq_marlin::marlin_mm_f16i4(
-      a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
-      g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
-      size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
-      num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-      thread_k, thread_n, sms, gptq_marlin::max_par);
+  if (a.scalar_type() == at::ScalarType::Half) {
+    gptq_marlin::marlin_mm_f16i4<half>(
+        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),
+        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n,
+        size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
+        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+        thread_k, thread_n, sms, gptq_marlin::max_par);
+  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
+    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
+        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
+        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n,
+        size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
+        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+        thread_k, thread_n, sms, gptq_marlin::max_par);
+  } else {
+    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
+  }
 
   return c;
 }

+ 61 - 0
kernels/quantization/gptq_marlin/gptq_marlin_dtypes.cuh

@@ -0,0 +1,61 @@
+#ifndef _data_types_cuh
+#define _data_types_cuh
+#include "gptq_marlin.cuh"
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+
+namespace gptq_marlin {
+
+template <typename scalar_t>
+class ScalarType {
+};
+
+template <>
+class ScalarType<half> {
+public:
+    using scalar_t = half;
+    using scalar_t2 = half2;
+
+    // Matrix fragments for tensor core instructions; their precise layout is
+    // documented here:
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+    using FragA = Vec<half2, 4>;
+    using FragB = Vec<half2, 2>;
+    using FragC = Vec<float, 4>;
+    using FragS = Vec<half2, 1>;
+
+    static __device__ float inline num2float(const half x) { return __half2float(x); }
+
+    static __device__ half2 inline num2num2(const half x) { return __half2half2(x); }
+
+    static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); }
+
+    static __host__ __device__ half inline float2num(const float x) { return __float2half(x); }
+};
+
+template <>
+class ScalarType<nv_bfloat16> {
+public:
+    using scalar_t = nv_bfloat16;
+    using scalar_t2 = nv_bfloat162;
+
+    using FragA = Vec<nv_bfloat162, 4>;
+    using FragB = Vec<nv_bfloat162, 2>;
+    using FragC = Vec<float, 4>;
+    using FragS = Vec<nv_bfloat162, 1>;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); }
+
+    static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); }
+
+    static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); }
+
+    static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); }
+#endif
+};
+
+}
+
+#endif