
Revert "fix the awq gemm kernels"

This reverts commit 20c27863c1658de339ad14a67e57b9520f491bba.
AlpinDale, 1 year ago
commit 579071b570
1 changed file with 3 additions and 18 deletions

+ 3 - 18
kernels/quantization/awq/gemm_kernels.cu

@@ -6,8 +6,7 @@ Adapted from https://github.com/mit-han-lab/llm-awq
   journal={arXiv},
   year={2023}
 }
- */
-
+*/
 
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -16,9 +15,6 @@ Adapted from https://github.com/mit-han-lab/llm-awq
 
 #include <cuda_fp16.h>
 
-namespace aphrodite {
-namespace awq {
-
 // Pack two half values.
 static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {
@@ -29,9 +25,6 @@ __pack_half2(const half x, const half y) {
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -220,15 +213,11 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
       }
     }
   }
-#endif
 }
 
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-  assert(false);
-#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -422,12 +411,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
       }
     }
   }
-#endif
 }
 
-} // namespace awq
-} // namespace aphrodite
-
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
@@ -473,7 +458,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        aphrodite::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -484,7 +469,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        aphrodite::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
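
For context on the host-side code the last two hunks touch: awq_gemm selects a kernel variant by the divisibility of the output-channel count (the first branch, which launches gemm_forward_4bit_cuda_m16n128k32, presumably checks num_out_channels % 128 == 0; the else-if shown above handles % 64 == 0), and the split-K partial outputs are reduced by the final _out_feats.sum(0). The sketch below is a minimal, CPU-only C++ illustration of that dispatch-and-reduce pattern, not the actual launch code; dispatch_awq_gemm, launch_m16n128k32, launch_m16n64k32 and reduce_split_k are hypothetical names used only for this example.

#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the two kernel launches shown in the diff above.
static void launch_m16n128k32(int /*split_k_iters*/) { std::puts("128-wide tile path"); }
static void launch_m16n64k32(int /*split_k_iters*/)  { std::puts("64-wide tile path"); }

// Mirror of the dispatch in awq_gemm: the 128-column tile is used when the
// output-channel count is a multiple of 128, otherwise the 64-column tile.
void dispatch_awq_gemm(int num_out_channels, int split_k_iters) {
  if (num_out_channels % 128 == 0) {
    launch_m16n128k32(split_k_iters);
  } else if (num_out_channels % 64 == 0) {
    launch_m16n64k32(split_k_iters);
  } else {
    throw std::invalid_argument("num_out_channels must be a multiple of 64");
  }
}

// Split-K reduction in miniature: each slice of the split-K dimension holds a
// partial output, and the result is their element-wise sum -- the role that
// _out_feats.sum(0) plays at the end of awq_gemm.
std::vector<float> reduce_split_k(const std::vector<std::vector<float>>& partials) {
  std::vector<float> out(partials.empty() ? 0 : partials[0].size(), 0.0f);
  for (const auto& slice : partials)
    for (std::size_t i = 0; i < slice.size(); ++i) out[i] += slice[i];
  return out;
}

int main() {
  dispatch_awq_gemm(/*num_out_channels=*/4096, /*split_k_iters=*/8);
  const auto summed = reduce_split_k({{1.0f, 2.0f}, {3.0f, 4.0f}});
  std::printf("reduced: %.1f %.1f\n", summed[0], summed[1]);
  return 0;
}

Staging the split-K partials along a separate leading dimension and summing them afterwards presumably avoids cross-block atomics on the output at the cost of a larger intermediate buffer; that trade-off is orthogonal to what this revert changes (the namespace wrapping and the __CUDA_ARCH__ guard).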