@@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int4 *__restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {}

-template <const int threads, // number of threads in a threadblock
+template <const int num_bits, // number of bits used for weights
+ const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
@@ -62,8 +63,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
- int64_t size_m, int64_t size_n, int64_t size_k,
- bool is_k_full) {
+ int64_t num_bits, int64_t size_m, int64_t size_n,
+ int64_t size_k, bool is_k_full) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@@ -114,11 +115,21 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return res;
}

+// Constructs destination register by taking bytes from 2 sources (based on mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+ uint32_t res;
+ asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+ : "=r"(res)
+ : "r"(a), "n"(start_byte), "n"(mask));
+ return res;
+}
+
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__device__ inline FragB dequant(int q) {
+__device__ inline FragB dequant_4bit(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) {
return frag_b;
}

+__device__ inline FragB dequant_8bit(int q) {
+ static constexpr uint32_t mask_for_elt_01 = 0x5250;
+ static constexpr uint32_t mask_for_elt_23 = 0x5351;
+ static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+ uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+ uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+ FragB frag_b;
+ frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
+ *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+ frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
+ *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+ return frag_b;
+}
+
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
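A note on the 8-bit path added above (annotation, not part of the patch): prmt interleaves each weight byte b with a 0x64 exponent byte, so the resulting fp16 bit pattern 0x64xx encodes the value 1024 + b, and subtracting the magic constant 0x6480 (which is 1152.0 in fp16) leaves b - 128. A minimal host-side sketch of that per-byte arithmetic; the actual kernel additionally interleaves the byte order via the 0x5250/0x5351 masks:

    #include <cstdint>

    // Reference only: what dequant_8bit computes for one packed byte,
    // i.e. map the unsigned byte b to the signed value (b - 128).
    inline float dequant_8bit_reference(uint32_t q, int byte_idx) {
      uint32_t b = (q >> (8 * byte_idx)) & 0xffu;
      // 0x6400 | b reinterpreted as fp16 is exactly 1024 + b, and
      // 0x6480 reinterpreted as fp16 is exactly 1152 = 1024 + 128.
      return static_cast<float>(1024 + b) - 1152.0f;
    }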
@@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}

+// Given 2 floats multiply by 2 scales (halves)
+__device__ inline void scale_float(float *c, FragS &s) {
+ __half *s_ptr = reinterpret_cast<__half *>(&s);
+ c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
+ c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
+}
+
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) {
if (threadIdx.x == 0) {
@@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
}
}

-template <const int threads, // number of threads in a threadblock
+template <const int num_bits, // number of bits used for weights
+ const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
@@ -286,6 +323,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.

+ constexpr int pack_factor = 32 / num_bits;
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
@@ -385,21 +424,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);

// B sizes/strides
- int b_gl_stride = 16 * prob_n / 32;
- constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
+ int b_gl_stride = 16 * prob_n / (pack_factor * 4);
+ constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
+ constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
+ constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
- int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
- constexpr int b_sh_wr_delta = threads;
- constexpr int b_sh_rd_delta = threads;
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
+ constexpr int b_sh_wr_delta = threads * b_thread_vecs;
+ constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
- constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks
- ? thread_k_blocks / group_blocks
- : 1;
+ constexpr int s_tb_groups =
+ !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
+ ? thread_k_blocks / group_blocks
+ : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
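For reference (annotation, not part of the patch): pack_factor = 32 / num_bits is the number of weights per 32-bit word, i.e. 8 for 4-bit and 4 for 8-bit, and the new stride expressions collapse to the old hard-coded 4-bit constants when num_bits == 4:

    // b_gl_stride: 16 * prob_n / (pack_factor * 4) == 16 * prob_n / 32 for 4-bit,
    //              which is exactly the previous expression.
    // b_sh_stride: ((thread_n_blocks * 16) * 16 / 8) / 4 == 32 * thread_n_blocks / 4.
    static_assert(32 / 4 == 8, "pack_factor for 4-bit");
    static_assert(32 / 8 == 4, "pack_factor for 8-bit");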
@@ -425,12 +468,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));

- int b_gl_rd =
- b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+ int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
+ (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
- int b_sh_wr = threadIdx.x;
- int b_sh_rd = threadIdx.x;
+ int b_sh_wr = threadIdx.x * b_thread_vecs;
+ int b_sh_rd = threadIdx.x * b_thread_vecs;

// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -442,8 +485,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// No act_order
int s_gl_rd;
if constexpr (!has_act_order) {
- s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
- s_sh_stride * slice_col + threadIdx.x;
+ if constexpr (group_blocks == -1) {
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+ } else {
+ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+ s_sh_stride * slice_col + threadIdx.x;
+ }
}
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
@@ -511,7 +558,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk

// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
- I4 frag_b_quant[2];
+ I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
@@ -575,7 +622,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
- cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+#pragma unroll
+ for (int j = 0; j < b_thread_vecs; j++) {
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
+ }
+
B_ptr[i] += b_gl_rd_delta_o;
}
@@ -602,15 +653,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
- cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+ cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
- cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr],
- &scales_ptr[s_gl_rd]);
+ cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
+ &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
@@ -641,14 +692,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
- frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
- &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+
+#pragma unroll
+ for (int i = 0; i < b_thread_vecs; i++) {
+ frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>(
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
+ }
};

bool is_same_group[stages];
int same_group_id[stages];

auto init_same_group = [&](int pipe) {
+ if constexpr (!has_act_order) {
+ is_same_group[pipe] = false;
+ same_group_id[pipe] = 0;
+ return;
+ }
+
int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage);
@@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
- int b_quant = frag_b_quant[k % 2][j];
- int b_quant_shift = b_quant >> 8;
+ FragB frag_b0;
+ FragB frag_b1;
+ if constexpr (num_bits == 4) {
+ int b_quant = frag_b_quant[k % 2][0][j];
+ int b_quant_shift = b_quant >> 8;
+
+ frag_b0 = dequant_4bit(b_quant);
+ frag_b1 = dequant_4bit(b_quant_shift);

- FragB frag_b0 = dequant(b_quant);
+ } else {
+ int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
+ int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+ int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+
+ frag_b0 = dequant_8bit(b_quant_0);
+ frag_b1 = dequant_8bit(b_quant_1);
+ }

// Apply scale to frag_b0
if constexpr (has_act_order) {
@@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}
}

- FragB frag_b1 = dequant(b_quant_shift);
-
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() {
- constexpr int red_off = threads / b_sh_stride / 2;
+ constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
- int red_idx = threadIdx.x / b_sh_stride;
- constexpr int red_sh_stride = b_sh_stride * 4 * 2;
- constexpr int red_sh_delta = b_sh_stride;
- int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
- (threadIdx.x % b_sh_stride);
+ int red_idx = threadIdx.x / b_sh_stride_threads;
+ constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
+ constexpr int red_sh_delta = b_sh_stride_threads;
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
+ (threadIdx.x % b_sh_stride_threads);

// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
@@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
};

// Since multiple threadblocks may process parts of the same column slice, we
- // finally have to globally reduce over the results. As the striped portioning
+ // finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
@@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto write = [&](int idx, float c0, float c1, FragS &s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1));

- // For per-column quantization we finally apply the scale here
- if constexpr (!has_act_order && group_blocks == -1) {
+ // For per-column quantization we finally apply the scale here (only for
+ // 4-bit)
+ if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
res = __hmul2(res, s[0]);
}

((half2 *)sh)[idx] = res;
};
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
@@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
+
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
@@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1) {
- if (last) {
+ if constexpr (num_bits == 8) {
if (s_sh_wr_pred) {
- cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
+ } else {
+ if (last) {
+ if (s_sh_wr_pred) {
+ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+ }
+ cp_async_fence();
+ }
}
}

thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) {
- if (last) {
+ if constexpr (num_bits == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
+
+ } else {
+ if (last) {
+ cp_async_wait<0>();
+ __syncthreads();
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
+ reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+ reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+ }
+ }
+ }
+ }
+
+ // For 8-bit channelwise, we apply the scale before the global reduction
+ // that converts the fp32 results to fp16 (so that we avoid possible
+ // overflow in fp16)
+ if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks; i++) {
+#pragma unroll
+ for (int j = 0; j < 4; j++) {
+ scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
+ frag_s[j / 2][2 * (j % 2) + 0]);
+ scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
+ frag_s[j / 2][2 * (j % 2) + 0]);
+
+ scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
+ frag_s[j / 2][2 * (j % 2) + 1]);
+ scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
+ frag_s[j / 2][2 * (j % 2) + 1]);
+ }
+ }
}
}
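Why the 8-bit channelwise case scales before the fp32 -> fp16 conversion (annotation with illustrative numbers, not part of the patch): the raw integer accumulators can exceed the fp16 range, while the scaled outputs do not.

    // Illustrative only:
    //   raw fp32 accumulator for one output:  70000.0f   (fine in fp32)
    //   per-channel scale:                     0.0005f
    //   scaled value written out:              35.0f      (fits in fp16)
    // Converting 70000.0f to fp16 before scaling would overflow, since the
    // largest finite fp16 value is 65504.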
@@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
}

- // if (blockIdx.x == 0 && threadIdx.x == 0) {
- // printf("Move\n");
- // }
start_pipes();
}
}
}
}

-#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
- else if (thread_m_blocks == THREAD_M_BLOCKS && \
+ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
- Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
- pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
+ Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
+ THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
- Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
- pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
+ Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
+ THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
@@ -1158,28 +1270,92 @@ typedef struct {
int num_threads;
} thread_config_t;

-thread_config_t small_batch_thread_configs[] = {
+typedef struct {
+ int max_m_blocks;
+ thread_config_t tb_cfg;
+} exec_config_t;
+
+thread_config_t thread_configs[] = {
// Ordered by priority

// thread_k, thread_n, num_threads
- {128, 128, 256}, // Default
- {128, 64, 128}, // Reduce N 2X, same K
- {64, 256, 256}, // Reduce K 2X, increase N 2X
- {64, 128, 128}, // Reduce K 2X, same N
+ {64, 256, 256}, // Default (max cache usage)
+ {64, 128, 128}, // Reduce N, reduce warps
+ {128, 64, 128}, // Reduce N more, but increase K
+
};

-thread_config_t large_batch_thread_configs[] = {
- // Ordered by priority
+int get_scales_cache_size(thread_config_t const &th_config, int prob_m,
+ int prob_n, int prob_k, int num_bits, int group_size,
+ bool has_act_order, bool is_k_full) {
+ bool cache_scales_chunk = has_act_order && !is_k_full;

- // thread_k, thread_n, num_threads
- {64, 256, 256}, // Default
- {128, 64, 128}, // Reduce N 2X, same K
- {64, 128, 128}, // Reduce N 2X, same K
- // {128, 64, 128}, // Reduce N 4X, increase K 2X
-};
+ int tb_n = th_config.thread_n;
+ int tb_k = th_config.thread_k;
+
+ // Get max scale groups per thread-block
+ int tb_groups;
+ if (group_size == -1) {
+ tb_groups = 1;
+ } else if (group_size == 0) {
+ tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
+ } else {
+ tb_groups = div_ceil(tb_k, group_size);
+ }
+
+ if (cache_scales_chunk) {
+ int load_groups =
+ tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
+ load_groups = max(load_groups, 32); // We load at least 32 scale groups
+ return load_groups * tb_n * 2;
+
+ } else {
+ int tb_scales = tb_groups * tb_n * 2;
+
+ return tb_scales * pipe_stages;
+ }
+}
+
+bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks,
+ int prob_m, int prob_n, int prob_k, int num_bits,
+ int scales_cache_size, int max_shared_mem) {
+ int pack_factor = 32 / num_bits;
+
+ // Get B size
+ int tb_k = th_config.thread_k;
+ int tb_n = th_config.thread_n;
+
+ int b_size = (tb_k * tb_n / pack_factor) * 4;
+
+ // Get A size
+ int m_blocks = div_ceil(prob_m, 16);
+ int tb_max_m = 16;

-bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
- int prob_k) {
+ while (true) {
+ if (m_blocks >= max_m_blocks) {
+ tb_max_m *= max_m_blocks;
+ break;
+ }
+
+ max_m_blocks--;
+ if (max_m_blocks == 0) {
+ TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
+ }
+ }
+
+ int a_size = (tb_max_m * tb_k) * 2;
+
+ float pipe_size = (a_size + b_size) * pipe_stages;
+
+ TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
+
+ return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+}
+
+bool is_valid_config(thread_config_t const &th_config, int max_m_blocks,
+ int prob_m, int prob_n, int prob_k, int num_bits,
+ int group_size, bool has_act_order, bool is_k_full,
+ int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
return false;
}

+ // Determine cache for scales
+ int scales_cache_size =
+ get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
+ group_size, has_act_order, is_k_full);
+
+ // Check that pipeline fits into cache
+ if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+ num_bits, scales_cache_size, max_shared_mem)) {
+ return false;
+ }
+
return true;
}

-thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
-
- // TODO: Enable if needed after some more testing
- if (prob_m <= 0) {
- for (auto th_config : small_batch_thread_configs) {
- if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
- return th_config;
+exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
+ int num_bits, int group_size,
+ bool has_act_order, bool is_k_full,
+ int max_shared_mem) {
+ int max_m_blocks = 4;
+ while (max_m_blocks > 0) {
+ for (auto th_config : thread_configs) {
+ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
+ num_bits, group_size, has_act_order, is_k_full,
+ max_shared_mem)) {
+ return exec_config_t{max_m_blocks, th_config};
}
}

- } else {
- for (auto th_config : large_batch_thread_configs) {
- if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
- return th_config;
- }
- }
+ printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
+ "GPU cache. This may "
+ "hurt performance. Consider upgrading your GPU.\n");
+
+ max_m_blocks--; // Process less M blocks per invocation to reduce cache
+ // usage
}

- return thread_config_t{-1, -1, -1};
+ return exec_config_t{0, {-1, -1, -1}};
}

-#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
- __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
- __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
- __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
- __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
- __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
- __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
- __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
- __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
- __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
- __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
- __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
- __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
- __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
- __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
- __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
- __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
- __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
- __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
- __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
- __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
-
-void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
- void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k,
- void *workspace, bool has_act_order, bool is_k_full,
- int num_groups, int group_size, int dev = 0,
- cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1,
- int sms = -1, int max_par = 16) {
+ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+
+void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
+ void *g_idx, void *perm, void *a_tmp, int prob_m,
+ int prob_n, int prob_k, void *workspace, int num_bits,
+ bool has_act_order, bool is_k_full, int num_groups,
+ int group_size, int dev, cudaStream_t stream, int thread_k,
+ int thread_n, int sms, int max_par) {
+ TORCH_CHECK(num_bits == 4 || num_bits == 8,
+ "num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
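A worked example of the feasibility check used by determine_thread_config above (annotation; the pipe_stages value and the shared-memory figure are assumptions for illustration, not taken from the patch). With the default config {thread_k = 64, thread_n = 256, 256 threads}, max_m_blocks = 4, num_bits = 4 and a large enough prob_m:

    // b_size    = (64 * 256 / 8) * 4          =  8192 bytes per stage
    // a_size    = (16 * 4) * 64 * 2           =  8192 bytes per stage
    // pipe_size = (8192 + 8192) * pipe_stages = 65536 bytes if pipe_stages == 4
    // The config is accepted when
    //   pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
    // on a GPU with roughly 100 KB of opt-in shared memory this fits, and when
    // it does not, the loop retries with a smaller max_m_blocks (shrinking a_size).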
@@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
TORCH_CHECK(max_shared_mem > 0);

// Set thread config
- thread_config_t th_config;
+ exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
- th_config = thread_config_t{thread_k, thread_n, default_threads};
+ exec_cfg =
+ exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
} else {
// Auto config
- th_config = determine_thread_config(prob_m, prob_n, prob_k);
+ exec_cfg =
+ determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
+ has_act_order, is_k_full, max_shared_mem);
}

- TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
- "Invalid thread config: thread_k = " + str(th_config.thread_k) +
- ", thread_n = " + str(th_config.thread_n) +
- ", num_threads = " + str(th_config.num_threads) +
- " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
- str(prob_n) + "]");
-
- int num_threads = th_config.num_threads;
- thread_k = th_config.thread_k;
- thread_n = th_config.thread_n;
+ TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
+ is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
+ prob_m, prob_n, prob_k, num_bits, group_size,
+ has_act_order, is_k_full, max_shared_mem),
+ "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
+ ", thread_k = ", exec_cfg.tb_cfg.thread_k,
+ ", thread_n = ", exec_cfg.tb_cfg.thread_n,
+ ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
+ prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
+ ", group_size = ", group_size,
+ ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
+ ", max_shared_mem = ", max_shared_mem);
+
+ int num_threads = exec_cfg.tb_cfg.num_threads;
+ thread_k = exec_cfg.tb_cfg.thread_k;
+ thread_n = exec_cfg.tb_cfg.thread_n;

int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
@@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
}

// Main loop
- for (int i = 0; i < tot_m_blocks; i += 4) {
+ for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
- if (thread_m_blocks > 4) {
+ if (thread_m_blocks > exec_cfg.max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any
// padding
- par = (16 * thread_m_blocks - pad) / 64;
+ par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
if (par > max_par)
par = max_par;
- prob_m = 64 * par;
- i += 4 * (par - 1);
- thread_m_blocks = 4;
+ prob_m = (16 * exec_cfg.max_m_blocks) * par;
+ i += exec_cfg.max_m_blocks * (par - 1);
+ thread_m_blocks = exec_cfg.max_m_blocks;
}

// Define kernel configurations
if (false) {
}
- CALL_IF(16, 4, 256)
- CALL_IF(8, 8, 256)
- CALL_IF(8, 4, 128)
- CALL_IF(4, 8, 128)
+ CALL_IF(4, 32, 2, 256)
+ CALL_IF(4, 16, 4, 256)
+ CALL_IF(4, 8, 4, 128)
+ CALL_IF(4, 4, 8, 128)
+ CALL_IF(8, 32, 2, 256)
+ CALL_IF(8, 16, 4, 256)
+ CALL_IF(8, 8, 4, 128)
+ CALL_IF(8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
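For readers mapping the CALL_IF lines back to the config table (annotation, not part of the patch): the host code sets thread_k_blocks = thread_k / 16 and thread_n_blocks = thread_n / 16, and CALL_IF takes (NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS), so for example:

    // {thread_k =  64, thread_n = 256, 256 threads} -> CALL_IF(bits, 16, 4, 256)
    // {thread_k =  64, thread_n = 128, 128 threads} -> CALL_IF(bits,  8, 4, 128)
    // {thread_k = 128, thread_n =  64, 128 threads} -> CALL_IF(bits,  4, 8, 128)
    // Each tuple is instantiated once with NUM_BITS = 4 and once with NUM_BITS = 8.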
@@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
- int64_t size_m, int64_t size_n, int64_t size_k,
- bool is_k_full) {
+ int64_t num_bits, int64_t size_m, int64_t size_n,
+ int64_t size_k, bool is_k_full) {
+ // Verify num_bits
+ TORCH_CHECK(num_bits == 4 || num_bits == 8,
+ "num_bits must be 4 or 8. Got = ", num_bits);
+ int pack_factor = 32 / num_bits;
+
// Verify A
- TORCH_CHECK(a.size(0) == size_m,
- "Shape mismatch: a.size(0) = " + str(a.size(0)) +
- ", size_m = " + str(size_m));
- TORCH_CHECK(a.size(1) == size_k,
- "Shape mismatch: a.size(1) = " + str(a.size(1)) +
- ", size_k = " + str(size_k));
+ TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
+ ", size_m = ", size_m);
+ TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
+ ", size_k = ", size_k);

// Verify B
- TORCH_CHECK(size_k % gptq_marlin::tile_size == 0,
- "size_k = " + str(size_k) + " is not divisible by tile_size = " +
- str(gptq_marlin::tile_size));
+ TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
+ " is not divisible by tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
- "Shape mismatch: b_q_weight.size(0) = " +
- str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
- ", tile_size = " + str(gptq_marlin::tile_size));
- TORCH_CHECK(
- b_q_weight.size(1) % gptq_marlin::tile_size == 0,
- "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
- " is not divisible by tile_size = " + str(gptq_marlin::tile_size));
- int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) *
- gptq_marlin::pack_factor_4bit;
- TORCH_CHECK(size_n == actual_size_n,
- "size_n = " + str(size_n) +
- ", actual_size_n = " + str(actual_size_n));
+ "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
+ ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
+ TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+ "b_q_weight.size(1) = ", b_q_weight.size(1),
+ " is not divisible by tile_size = ", gptq_marlin::tile_size);
+ int actual_size_n =
+ (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+ TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
+ ", actual_size_n = ", actual_size_n);

// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
@@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
// Verify g_idx and perm
TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
(g_idx.size(0) == size_k && perm.size(0) == size_k),
- "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) +
- " and perm.size(0) = " + str(perm.size(0)) +
- ", where size_k = " + str(size_k));
+ "Unexpected g_idx.size(0) = ", g_idx.size(0),
+ " and perm.size(0) = ", perm.size(0),
+ ", where size_k = ", size_k);

// Detect groupsize and act_order
int num_groups = -1;
@@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
- TORCH_CHECK(size_k % num_groups == 0,
- "size_k = " + str(size_k) +
- ", is not divisible by num_groups = " + str(num_groups));
+ TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
+ ", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
@@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,

} else {
if (num_groups > 1) {
- TORCH_CHECK(size_k % num_groups == 0,
- "size_k = " + str(size_k) +
- ", is not divisible by b_scales.size(0) = " +
- str(b_scales.size(0)));
+ TORCH_CHECK(
+ size_k % num_groups == 0, "size_k = ", size_k,
+ ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
group_size = size_k / num_groups;
} else {
group_size = -1;
@@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
}

// Verify workspace size
- TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0,
- "size_n = " + str(size_n) +
- ", is not divisible by min_thread_n = " +
- str(gptq_marlin::min_thread_n));
+ TORCH_CHECK(
+ size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
+ ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
- "workspace.numel = " + str(workspace.numel()) +
- " is below min_workspace_size = " + str(min_workspace_size));
+ "workspace.numel = ", workspace.numel(),
+ " is below min_workspace_size = ", min_workspace_size);

int dev = a.get_device();
- gptq_marlin::marlin_cuda(
+ gptq_marlin::marlin_mm_f16i4(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
- size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups,
- group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
- sms, gptq_marlin::max_par);
+ size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
+ num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
+ thread_k, thread_n, sms, gptq_marlin::max_par);

return c;
}