@@ -6,8 +6,7 @@ Adapted from https://github.com/mit-han-lab/llm-awq
   journal={arXiv},
   year={2023}
 }
- */
-
+*/
 
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -16,9 +15,6 @@ Adapted from https://github.com/mit-han-lab/llm-awq
 
 #include <cuda_fp16.h>
 
-namespace aphrodite {
-namespace awq {
-
 // Pack two half values.
 static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {
@@ -29,9 +25,6 @@ __pack_half2(const half x, const half y) {
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -220,15 +213,11 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
       }
     }
   }
-#endif
 }
 
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -422,12 +411,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
       }
     }
   }
-#endif
 }
 
-} // namespace awq
-} // namespace aphrodite
-
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
@@ -473,7 +458,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        aphrodite::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -484,7 +469,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        aphrodite::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
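
Note: the second hunk keeps the __pack_half2 helper as context. A minimal host-side sketch of the packing it performs (an illustration, not part of the patch; it assumes the usual llm-awq implementation, which returns (y_bits << 16) | x_bits, and uses uint16_t bit patterns in place of CUDA's half type):

#include <cstdint>
#include <cstdio>

// Pack two half-precision bit patterns into one 32-bit word:
// y lands in the high 16 bits, x in the low 16 bits.
static uint32_t pack_half2_bits(uint16_t x_bits, uint16_t y_bits) {
  return (static_cast<uint32_t>(y_bits) << 16) | x_bits;
}

int main() {
  const uint16_t one = 0x3C00;  // 1.0 in IEEE 754 binary16
  const uint16_t two = 0x4000;  // 2.0 in IEEE 754 binary16
  printf("0x%08X\n", pack_half2_bits(one, two));  // prints 0x40003C00
  return 0;
}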
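
The layout comments kept as context in the third-to-last hunk (in_feats: M x IC float16, kernel: IC x OC/8 int32 with eight 4-bit weights per word, scaling_factors: IC/G x OC float16) pin down the operand shapes awq_gemm expects. A minimal libtorch sketch that allocates matching dummy tensors; the awq_gemm parameter order and the zeros shape are assumptions inferred from the kernel-launch arguments visible in the patch, not confirmed by it:

#include <torch/torch.h>

// Assumed declaration; the definition lives in the patched file.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int split_k_iters);

torch::Tensor awq_gemm_shape_demo() {
  const int64_t M = 16, IC = 4096, OC = 4096, G = 128;  // G: quantization group size
  auto fp16 = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto i32  = torch::dtype(torch::kInt32).device(torch::kCUDA);
  auto in_feats = torch::randn({M, IC}, fp16);          // M x IC [float16]
  auto kernel   = torch::zeros({IC, OC / 8}, i32);      // packed 4-bit weights
  auto scales   = torch::randn({IC / G, OC}, fp16);     // IC/G x OC [float16]
  auto zeros    = torch::zeros({IC / G, OC / 8}, i32);  // assumed packed like kernel
  // OC % 128 == 0, so this call takes the m16n128k32 path shown above.
  return awq_gemm(in_feats, kernel, scales, zeros, /*split_k_iters=*/8);  // M x OC
}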