|
@@ -20,6 +20,11 @@
|
|
|
*/
|
|
|
|
|
|
#include "gptq_marlin.cuh"
|
|
|
+#include "gptq_marlin_dtypes.cuh"
|
|
|
+
|
|
|
+#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\
|
|
|
+ std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
|
|
|
+ "only float16 and bfloat16 is supported");
|
|
|
|
|
|
template <typename T> inline std::string str(T x) { return std::to_string(x); }
|
|
|
|
|
@@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
|
|
|
int4 *__restrict__ out_int4_ptr, int size_m,
|
|
|
int size_k, int block_rows) {}
|
|
|
|
|
|
-template <const int num_bits, // number of bits used for weights
|
|
|
+template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
|
|
+ const int num_bits, // number of bits used for weights
|
|
|
const int threads, // number of threads in a threadblock
|
|
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
|
|
// dimension (batchsize) of the threadblock
|
|
@@ -72,31 +78,37 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|
|
|
|
|
#else
|
|
|
|
|
|
-// Matrix fragments for tensor core instructions; their precise layout is
|
|
|
-// documented here:
|
|
|
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
|
|
|
-using FragA = Vec<half2, 4>;
|
|
|
-using FragB = Vec<half2, 2>;
|
|
|
-using FragC = Vec<float, 4>;
|
|
|
-using FragS = Vec<half2, 1>; // quantization scales
|
|
|
|
|
|
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
|
|
// output/accumulation.
|
|
|
-__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
|
|
|
- FragC &frag_c) {
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
|
|
|
+ const typename ScalarType<scalar_t>::FragB &frag_b,
|
|
|
+ typename ScalarType<scalar_t>::FragC &frag_c) {
|
|
|
const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
|
|
|
const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
|
|
|
float *c = reinterpret_cast<float *>(&frag_c);
|
|
|
- asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
|
|
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
|
|
- : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
|
|
- : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
|
|
|
- "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
|
|
+ if constexpr (std::is_same<scalar_t, half>::value) {
|
|
|
+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
|
|
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
|
|
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
|
|
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
|
|
|
+ "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
|
|
+ } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
|
|
+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
|
|
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
|
|
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
|
|
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
|
|
|
+ "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
|
|
+ } else {
|
|
|
+ STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
|
|
// memory, directly in tensor core layout.
|
|
|
-__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a, const void *smem_ptr) {
|
|
|
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
|
|
|
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
|
|
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
|
@@ -128,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) {
|
|
|
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
|
|
// values. We mostly follow the strategy in the link below, with some small
|
|
|
// changes:
|
|
|
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
|
|
-__device__ inline FragB dequant_4bit(int q) {
|
|
|
+// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
|
|
+// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
|
|
|
+ STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
|
|
+}
|
|
|
+
|
|
|
+template <>
|
|
|
+__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
|
|
|
const int LO = 0x000f000f;
|
|
|
const int HI = 0x00f000f0;
|
|
|
const int EX = 0x64006400;
|
|
@@ -141,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) {
|
|
|
const int SUB = 0x64086408;
|
|
|
const int MUL = 0x2c002c00;
|
|
|
const int ADD = 0xd480d480;
|
|
|
- FragB frag_b;
|
|
|
+ typename ScalarType<half>::FragB frag_b;
|
|
|
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
|
|
|
*reinterpret_cast<const half2 *>(&SUB));
|
|
|
frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
|
|
@@ -150,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) {
|
|
|
return frag_b;
|
|
|
}
|
|
|
|
|
|
-__device__ inline FragB dequant_8bit(int q) {
|
|
|
+template <>
|
|
|
+__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_4bit<nv_bfloat16>(int q) {
|
|
|
+ static constexpr uint32_t MASK = 0x000f000f;
|
|
|
+ static constexpr uint32_t EX = 0x43004300;
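+ // 0x4300 is 128.0 in bf16, so OR-ing a 4-bit value into the mantissa yields 128 + q.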
|
|
|
+
|
|
|
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
|
|
|
|
|
|
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
|
|
+ q >>= 4;
|
|
|
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
|
|
+
|
|
|
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
|
|
|
+ static constexpr uint32_t MUL = 0x3F803F80;
|
|
|
+ static constexpr uint32_t ADD = 0xC308C308;
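+ // MUL is 1.0 and ADD is -136.0 in bf16, so the FMAs below compute (128 + q) - 136 = q - 8,
+ // the same zero point of 8 used by the fp16 path.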
|
|
|
+
|
|
|
+ frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&lo),
|
|
|
+ *reinterpret_cast<const nv_bfloat162*>(&MUL),
|
|
|
+ *reinterpret_cast<const nv_bfloat162*>(&ADD));
|
|
|
+ frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
|
|
+ *reinterpret_cast<const nv_bfloat162*>(&MUL),
|
|
|
+ *reinterpret_cast<const nv_bfloat162*>(&ADD));
|
|
|
+ return frag_b;
|
|
|
+}
|
|
|
+
|
|
|
+// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8-bit int values to fp16 or bf16.
|
|
|
+// Reference:
|
|
|
+// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
|
|
+// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
|
|
|
+ STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
|
|
+}
|
|
|
+
|
|
|
+template <>
|
|
|
+__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
|
|
|
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
|
|
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
|
|
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
|
@@ -160,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) {
|
|
|
|
|
|
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
|
|
|
|
|
- FragB frag_b;
|
|
|
+ typename ScalarType<half>::FragB frag_b;
|
|
|
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
|
|
|
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
|
|
|
frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
|
|
@@ -168,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) {
|
|
|
return frag_b;
|
|
|
}
|
|
|
|
|
|
+template <>
|
|
|
+__device__ inline typename ScalarType<nv_bfloat16>::FragB dequant_8bit<nv_bfloat16>(int q) {
|
|
|
+ typename ScalarType<nv_bfloat16>::FragB frag_b;
|
|
|
+
|
|
|
+ float fp32_intermediates[4];
|
|
|
+ uint32_t * fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
|
|
+
|
|
|
+ static constexpr uint32_t fp32_base = 0x4B000000;
|
|
|
+ fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
|
|
+ fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
|
|
+ fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
|
|
+ fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
|
|
+
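+ // fp32_base is 2^23 (8388608.0f), so each packed value reads as 8388608 + b for byte b;
+ // subtracting 8388736.0f (= 2^23 + 128) leaves b - 128, i.e. the int8 zero point of 128.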
|
|
|
+ fp32_intermediates[0] -= 8388736.f;
|
|
|
+ fp32_intermediates[1] -= 8388736.f;
|
|
|
+ fp32_intermediates[2] -= 8388736.f;
|
|
|
+ fp32_intermediates[3] -= 8388736.f;
|
|
|
+
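+ // bf16 is the upper 16 bits of the corresponding fp32 (exact here, since the values are
+ // small integers), so selector 0x7632 packs the high halves of two fp32 intermediates
+ // into one 32-bit word holding two bf16 results.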
|
|
|
+ uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
|
|
+ bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
|
|
|
+ bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
|
|
|
+
|
|
|
+ return frag_b;
|
|
|
+}
|
|
|
+
|
|
|
// Multiply dequantized values by the corresponding quantization scale; used
|
|
|
// only for grouped quantization.
|
|
|
-__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
|
|
|
- half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline void scale(typename ScalarType<scalar_t>::FragB &frag_b,
|
|
|
+ typename ScalarType<scalar_t>::FragS &frag_s, int i) {
|
|
|
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
|
|
+ scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]);
|
|
|
frag_b[0] = __hmul2(frag_b[0], s);
|
|
|
frag_b[1] = __hmul2(frag_b[1], s);
|
|
|
}
|
|
|
|
|
|
// Same as above, but for act_order (each K is multiplied individually)
|
|
|
-__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
|
|
|
- FragS &frag_s_3, FragS &frag_s_4, int i) {
|
|
|
- __half2 s_val_1_2;
|
|
|
- s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i];
|
|
|
- s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i];
|
|
|
-
|
|
|
- __half2 s_val_3_4;
|
|
|
- s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i];
|
|
|
- s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i];
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
|
|
|
+ typename ScalarType<scalar_t>::FragS &frag_s_1,
|
|
|
+ typename ScalarType<scalar_t>::FragS &frag_s_2,
|
|
|
+ typename ScalarType<scalar_t>::FragS &frag_s_3,
|
|
|
+ typename ScalarType<scalar_t>::FragS &frag_s_4,
|
|
|
+ int i) {
|
|
|
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
|
|
+ scalar_t2 s_val_1_2;
|
|
|
+ s_val_1_2.x = reinterpret_cast<scalar_t *>(&frag_s_1)[i];
|
|
|
+ s_val_1_2.y = reinterpret_cast<scalar_t *>(&frag_s_2)[i];
|
|
|
+
|
|
|
+ scalar_t2 s_val_3_4;
|
|
|
+ s_val_3_4.x = reinterpret_cast<scalar_t *>(&frag_s_3)[i];
|
|
|
+ s_val_3_4.y = reinterpret_cast<scalar_t *>(&frag_s_4)[i];
|
|
|
|
|
|
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
|
|
|
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
|
|
|
}
|
|
|
|
|
|
// Given 2 floats multiply by 2 scales (halves)
|
|
|
-__device__ inline void scale_float(float *c, FragS &s) {
|
|
|
- __half *s_ptr = reinterpret_cast<__half *>(&s);
|
|
|
- c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
|
|
|
- c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
|
|
|
+template <typename scalar_t>
|
|
|
+__device__ inline void scale_float(float *c, typename ScalarType<scalar_t>::FragS &s) {
|
|
|
+ scalar_t *s_ptr = reinterpret_cast<scalar_t *>(&s);
|
|
|
+ c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
|
|
|
+ c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
|
|
|
}
|
|
|
|
|
|
// Wait until barrier reaches `count`, then lock for current threadblock.
|
|
@@ -286,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-template <const int num_bits, // number of bits used for weights
|
|
|
+template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
|
|
+ const int num_bits, // number of bits used for weights
|
|
|
const int threads, // number of threads in a threadblock
|
|
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
|
|
// dimension (batchsize) of the threadblock
|
|
@@ -322,6 +411,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
// ensures good utilization of all SMs for many kinds of shape and GPU
|
|
|
// configurations, while requiring as few slow global cross-threadblock
|
|
|
// reductions as possible.
|
|
|
+ using Dtype = ScalarType<scalar_t>;
|
|
|
+ using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
|
|
+ using FragA = typename ScalarType<scalar_t>::FragA;
|
|
|
+ using FragB = typename ScalarType<scalar_t>::FragB;
|
|
|
+ using FragC = typename ScalarType<scalar_t>::FragC;
|
|
|
+ using FragS = typename ScalarType<scalar_t>::FragS;
|
|
|
|
|
|
constexpr int pack_factor = 32 / num_bits;
|
|
|
|
|
@@ -690,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < thread_m_blocks; i++)
|
|
|
- ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
|
|
+ ldsm4<scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
|
|
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
|
|
|
|
|
|
#pragma unroll
|
|
@@ -834,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
int b_quant = frag_b_quant[k % 2][0][j];
|
|
|
int b_quant_shift = b_quant >> 8;
|
|
|
|
|
|
- frag_b0 = dequant_4bit(b_quant);
|
|
|
- frag_b1 = dequant_4bit(b_quant_shift);
|
|
|
+ frag_b0 = dequant_4bit<scalar_t>(b_quant);
|
|
|
+ frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
|
|
|
|
|
|
} else {
|
|
|
int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
|
|
|
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
|
|
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
|
|
|
|
|
- frag_b0 = dequant_8bit(b_quant_0);
|
|
|
- frag_b1 = dequant_8bit(b_quant_1);
|
|
|
+ frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
|
|
+ frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
|
|
}
|
|
|
|
|
|
// Apply scale to frag_b0
|
|
|
if constexpr (has_act_order) {
|
|
|
- scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
|
|
+ scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
|
|
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
|
|
|
} else {
|
|
|
if constexpr (group_blocks != -1) {
|
|
|
- scale(frag_b0, frag_s[k % 2][j], 0);
|
|
|
+ scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Apply scale to frag_b1
|
|
|
if constexpr (has_act_order) {
|
|
|
- scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
|
|
+ scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
|
|
act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
|
|
|
|
|
|
} else {
|
|
|
if constexpr (group_blocks != -1) {
|
|
|
- scale(frag_b1, frag_s[k % 2][j], 1);
|
|
|
+ scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < thread_m_blocks; i++) {
|
|
|
- mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
|
|
- mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
|
|
+ mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
|
|
+ mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
|
|
}
|
|
|
}
|
|
|
};
|
|
@@ -978,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
for (int j = 0; j < 2 * 4; j++) {
|
|
|
reinterpret_cast<float *>(
|
|
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
|
|
- __half2float(reinterpret_cast<__half *>(&c_red)[j]);
|
|
|
+ Dtype::num2float(reinterpret_cast<scalar_t *>(&c_red)[j]);
|
|
|
}
|
|
|
}
|
|
|
if (!last) {
|
|
|
int4 c;
|
|
|
#pragma unroll
|
|
|
for (int j = 0; j < 2 * 4; j++) {
|
|
|
- reinterpret_cast<__half *>(&c)[j] =
|
|
|
- __float2half(reinterpret_cast<float *>(
|
|
|
+ reinterpret_cast<scalar_t *>(&c)[j] =
|
|
|
+ Dtype::float2num(reinterpret_cast<float *>(
|
|
|
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
|
|
}
|
|
|
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
|
@@ -1021,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
// We first reorder in shared memory to guarantee the most efficient final
|
|
|
// global write patterns
|
|
|
auto write = [&](int idx, float c0, float c1, FragS &s) {
|
|
|
- half2 res = __halves2half2(__float2half(c0), __float2half(c1));
|
|
|
+ scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
|
|
|
|
|
// For per-column quantization we finally apply the scale here (only for
|
|
|
// 4-bit)
|
|
@@ -1029,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
res = __hmul2(res, s[0]);
|
|
|
}
|
|
|
|
|
|
- ((half2 *)sh)[idx] = res;
|
|
|
+ ((scalar_t2 *)sh)[idx] = res;
|
|
|
};
|
|
|
|
|
|
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
|
@@ -1191,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
for (int i = 0; i < thread_m_blocks; i++) {
|
|
|
#pragma unroll
|
|
|
for (int j = 0; j < 4; j++) {
|
|
|
- scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
|
|
|
+ scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
|
|
|
frag_s[j / 2][2 * (j % 2) + 0]);
|
|
|
- scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
|
|
|
+ scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
|
|
|
frag_s[j / 2][2 * (j % 2) + 0]);
|
|
|
|
|
|
- scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
|
|
|
+ scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
|
|
|
frag_s[j / 2][2 * (j % 2) + 1]);
|
|
|
- scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
|
|
|
+ scale_float<scalar_t>(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
|
|
|
frag_s[j / 2][2 * (j % 2) + 1]);
|
|
|
}
|
|
|
}
|
|
@@ -1246,22 +1341,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
|
|
- HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
|
|
|
- else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
|
|
- thread_n_blocks == THREAD_N_BLOCKS && \
|
|
|
- thread_k_blocks == THREAD_K_BLOCKS && \
|
|
|
- has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
|
|
- num_threads == NUM_THREADS) { \
|
|
|
- cudaFuncSetAttribute( \
|
|
|
- Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
|
|
- THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
|
|
|
- cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
|
|
- Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
|
|
- THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
|
|
|
- <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
|
|
- A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
|
|
|
- prob_k, locks); \
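+// `scalar_t` in this macro binds to the template parameter of marlin_mm_f16i4, where the
+// macro is expanded, selecting the compute dtype of the instantiated Marlin kernel.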
|
|
|
+#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
|
|
+ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
|
|
|
+ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
|
|
+ thread_n_blocks == THREAD_N_BLOCKS && \
|
|
|
+ thread_k_blocks == THREAD_K_BLOCKS && \
|
|
|
+ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
|
|
+ num_threads == NUM_THREADS) { \
|
|
|
+ cudaFuncSetAttribute( \
|
|
|
+ Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
|
|
+ THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
|
|
|
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
|
|
+ Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
|
|
+ THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
|
|
|
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
|
|
+ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
|
|
|
+ prob_k, locks); \
|
|
|
}
|
|
|
|
|
|
typedef struct {
|
|
@@ -1461,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
|
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
|
|
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
|
|
|
|
|
|
+template <typename scalar_t>
|
|
|
void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
|
|
|
void *g_idx, void *perm, void *a_tmp, int prob_m,
|
|
|
int prob_n, int prob_k, void *workspace, int num_bits,
|
|
@@ -1730,12 +1826,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|
|
" is below min_workspace_size = ", min_workspace_size);
|
|
|
|
|
|
int dev = a.get_device();
|
|
|
- gptq_marlin::marlin_mm_f16i4(
|
|
|
- a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
|
|
|
- g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
|
|
|
- size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
|
|
|
- num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
|
|
- thread_k, thread_n, sms, gptq_marlin::max_par);
|
|
|
+ if (a.scalar_type() == at::ScalarType::Half) {
|
|
|
+ gptq_marlin::marlin_mm_f16i4<half>(
|
|
|
+ a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),
|
|
|
+ g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n,
|
|
|
+ size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
|
|
|
+ num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
|
|
+ thread_k, thread_n, sms, gptq_marlin::max_par);
|
|
|
+ } else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
|
|
+ gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
|
|
+ a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
|
|
+ g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n,
|
|
|
+ size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
|
|
|
+ num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
|
|
+ thread_k, thread_n, sms, gptq_marlin::max_par);
|
|
|
+ } else {
|
|
|
+ TORCH_CHECK(false, "gptq_marlin_gemm only supports bfloat16 and float16");
|
|
|
+ }
|
|
|
|
|
|
return c;
|
|
|
}
|