@@ -48,12 +48,12 @@ namespace marlin_24 {
 // than 1 warp per schedule allows some more latency hiding. At the same time,
 // we want relatively few warps to have many registers per warp and small tiles.
 static constexpr int THREADS = 256;
-static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
+static constexpr int STAGES = 4;
 
 static constexpr int min_thread_n = 128;
 
 static constexpr int tile_size = 16;
-static constexpr int max_par = 16;
+static constexpr int max_par = 64;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
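
Note: the bump of max_par from 16 to 64 works together with the max_m_blocks = 4 scheduling change further down in this patch: one kernel launch can now cover up to max_par * max_m_blocks * 16 = 4096 elements of the n dimension instead of 1024. A standalone back-of-envelope check (values taken from this patch):

    #include <cstdio>

    int main() {
      constexpr int max_m_blocks = 4;            // introduced later in this patch
      constexpr int slice = max_m_blocks * 16;   // n-elements per parallel slice
      std::printf("old cap: %d\n", 16 * slice);  // 1024
      std::printf("new cap: %d\n", 64 * slice);  // 4096
    }
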
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
     for (int pipe = 0; pipe < stages;) {
       fetch_to_shared((pipe + stages - 1) % stages, pipe,
                       slice_iters >= stages);
+      matmul(pipe);
       wait_for_stage();
 
       fetch_to_registers(pipe + 1, (pipe + 1) % stages);
-      matmul(pipe);
 
       pipe++;
       slice_iters--;
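
Note: this hunk is a pure reordering: matmul(pipe) now issues before wait_for_stage() instead of after fetch_to_registers(...). The likely rationale (my reading, not stated in the patch) is that the math for stage `pipe` only touches data that is already resident, so issuing it before blocking on the next stage's asynchronous copies overlaps tensor-core work with the in-flight global-to-shared traffic. A host-side mock of the pattern, with stub bodies standing in for the kernel's device helpers:

    #include <cstdio>

    // Stubs standing in for the kernel's real device-side helpers.
    void fetch_to_shared(int buf, int iter) { std::printf("async fetch -> buf %d (iter %d)\n", buf, iter); }
    void matmul(int pipe) { std::printf("matmul on stage %d\n", pipe); }
    void wait_for_stage() { std::printf("block on async copy group\n"); }
    void fetch_to_registers(int pipe) { std::printf("registers <- stage %d\n", pipe); }

    int main() {
      constexpr int stages = 4;
      int slice_iters = stages;
      for (int pipe = 0; pipe < stages;) {
        fetch_to_shared((pipe + stages - 1) % stages, pipe);
        matmul(pipe);      // issue math on already-resident data first...
        wait_for_stage();  // ...then block on the next stage's copies
        fetch_to_registers((pipe + 1) % stages);
        pipe++;
        slice_iters--;
        if (slice_iters == 0) break;
      }
    }
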
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
       // than better compute utilization
       thread_k = 128;
       thread_m = 128;
-    } else {
+    } else if (prob_n <= 256) {
       thread_k = 64;
       thread_m = 256;
+    } else {
+      thread_k = 32;
+      thread_m = 512;
     }
   }
 
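
Note: prob_n here is the batch dimension of the activation. The old code used thread_k = 64, thread_m = 256 for every batch above the small-batch cutoff; the new branch gives large batches a wider 512 tile with a narrower 32-wide k slice. A hedged restatement of the whole heuristic as a free function; the small-batch threshold of 16 is inferred from the context lines above, so treat it as an assumption:

    // Sketch only: mirrors the patched heuristic for default tile shapes.
    static void pick_tile_shape(int prob_n, int& thread_k, int& thread_m) {
      if (prob_n <= 16) {          // small batch: partitioning beats utilization
        thread_k = 128; thread_m = 128;
      } else if (prob_n <= 256) {  // mid-size batch
        thread_k = 64;  thread_m = 256;
      } else {                     // new: large batch gets the 512-wide tile
        thread_k = 32;  thread_m = 512;
      }
    }
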
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
   int4* C_ptr = (int4*)C;
   const int4* s_ptr = (const int4*)s;
 
+  constexpr int max_m_blocks = 4;
+
   int* locks = (int*)workspace;
-  for (int i = 0; i < tot_n_blocks; i += 4) {
+  for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
     int thread_n_blocks = tot_n_blocks - i;
     prob_n = tot_n - 16 * i;
     int par = 1;
-    if (thread_n_blocks > 4) {
+    if (thread_n_blocks > max_m_blocks) {
       // Note that parallel > 1 currently only works for inputs without any
       // padding
-      par = (16 * thread_n_blocks - pad) / 64;
+      par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
       if (par > max_par) par = max_par;
-      prob_n = 64 * par;
-      i += 4 * (par - 1);
-      thread_n_blocks = 4;
+      prob_n = (max_m_blocks * 16) * par;
+      i += max_m_blocks * (par - 1);
+      thread_n_blocks = max_m_blocks;
     }
 
     // For compilation speed, we only define the kernel configurations that have
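
Note: this refactor replaces the hard-coded 4 and 64 literals with the named constant max_m_blocks without changing behavior. To make the slice math concrete, a standalone mock of the loop above with example values (tot_n and pad are illustrative, not from the patch):

    #include <cstdio>

    int main() {
      constexpr int max_m_blocks = 4, max_par = 64;
      int tot_n = 2048, pad = 0;        // example problem, no padding
      int tot_n_blocks = tot_n / 16;    // 128 blocks of 16
      for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
        int thread_n_blocks = tot_n_blocks - i;
        int par = 1;
        if (thread_n_blocks > max_m_blocks) {
          par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);  // 32 here
          if (par > max_par) par = max_par;
          i += max_m_blocks * (par - 1);  // skip the blocks covered in parallel
          thread_n_blocks = max_m_blocks;
        }
        std::printf("one launch covers par=%d x %d n-elements\n", par,
                    thread_n_blocks * 16);  // par=32 x 64 -> all 2048
      }
    }
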
@@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     if (false) {
     }  // BMxBNxBK, group
     // 4-bit
-    CALL_IF_2_4(4, 8, 1, 4, -1)   // e.g., 16x128x128
-    CALL_IF_2_4(4, 8, 1, 4, 4)    // e.g., 16x128x128, 64
+    CALL_IF_2_4(4, 8, 1, 4, -1)  // e.g., 16x128x128
+    CALL_IF_2_4(4, 8, 1, 4, 4)   // e.g., 16x128x128, 64
+
     CALL_IF_2_4(4, 16, 1, 2, -1)  // e.g., 16x256x64
     CALL_IF_2_4(4, 16, 1, 2, 4)   // e.g., 16x256x64, 64
     CALL_IF_2_4(4, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     CALL_IF_2_4(4, 16, 4, 2, -1)
     CALL_IF_2_4(4, 16, 4, 2, 4)
 
+    CALL_IF_2_4(4, 32, 1, 1, -1)  // e.g., 16x512x32
+    CALL_IF_2_4(4, 32, 1, 1, 4)   // e.g., 16x512x32, 64
+    CALL_IF_2_4(4, 32, 2, 1, -1)  // e.g., 32x512x32
+    CALL_IF_2_4(4, 32, 2, 1, 4)
+    CALL_IF_2_4(4, 32, 3, 1, -1)
+    CALL_IF_2_4(4, 32, 3, 1, 4)
+    CALL_IF_2_4(4, 32, 4, 1, -1)
+    CALL_IF_2_4(4, 32, 4, 1, 4)
+
     // 8-bit
-    CALL_IF_2_4(8, 8, 1, 4, -1)   // e.g., 16x128x128
-    CALL_IF_2_4(8, 8, 1, 4, 4)    // e.g., 16x128x128, 64
+    CALL_IF_2_4(8, 8, 1, 4, -1)  // e.g., 16x128x128
+    CALL_IF_2_4(8, 8, 1, 4, 4)   // e.g., 16x128x128, 64
+
     CALL_IF_2_4(8, 16, 1, 2, -1)  // e.g., 16x256x64
     CALL_IF_2_4(8, 16, 1, 2, 4)   // e.g., 16x256x64, 64
     CALL_IF_2_4(8, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
     CALL_IF_2_4(8, 16, 3, 2, 4)
     CALL_IF_2_4(8, 16, 4, 2, -1)
     CALL_IF_2_4(8, 16, 4, 2, 4)
+
+    CALL_IF_2_4(8, 32, 1, 1, -1)  // e.g., 16x512x32
+    CALL_IF_2_4(8, 32, 1, 1, 4)   // e.g., 16x512x32, 64
+    CALL_IF_2_4(8, 32, 2, 1, -1)  // e.g., 32x512x32
+    CALL_IF_2_4(8, 32, 2, 1, 4)
+    CALL_IF_2_4(8, 32, 3, 1, -1)
+    CALL_IF_2_4(8, 32, 3, 1, 4)
+    CALL_IF_2_4(8, 32, 4, 1, -1)
+    CALL_IF_2_4(8, 32, 4, 1, 4)
     else {
       throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                                ", " + str(prob_k) + ", " + str(prob_n) + "]" +
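
Note: the "e.g." comments on the new 32-block entries above have been corrected here; as submitted they repeated "16x256x64" from the 16-block rows. The BMxBNxBK shapes follow from the macro arguments; reconstructing the mapping from the existing entries (an inference from the table, not the macro's definition): BM = arg3 * 16, BN = arg2 * 16, BK = arg4 * 32, where the k tile is 32 deep because of the 2:4 sparsity. A quick check against the table:

    #include <cstdio>

    // Inferred: CALL_IF_2_4(bits, n_blocks, m_blocks, k_blocks, group_blocks)
    void tile_shape(int n_blocks, int m_blocks, int k_blocks) {
      std::printf("%dx%dx%d\n", m_blocks * 16, n_blocks * 16, k_blocks * 32);
    }

    int main() {
      tile_shape(16, 2, 2);  // existing entry -> 32x256x64, matches its comment
      tile_shape(32, 1, 1);  // new entry      -> 16x512x32
    }
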
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   int thread_k = -1;
   int thread_m = -1;
   int sms = -1;
-  int max_par = 16;
+  int max_par = marlin_24::max_par;
   int groupsize = -1;
   if (b_scales.size(0) > 1) {
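
Note: reading the host-side default from marlin_24::max_par removes a duplicated literal that would otherwise drift (the wrapper still said 16 while the kernel constant moved to 64). If a local copy were ever reintroduced, a guard along these lines would catch the drift at compile time (a suggestion, not part of the patch):

    static_assert(marlin_24::max_par == 64,
                  "host-side max_par default must track marlin_24::max_par");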