
fix: bnb on Turing GPUs (#299)

* Revert "fix: bump bnb kernels to sm_80 due to async stream copies"

This reverts commit 49793d7c5a8064cbc01e52db623e26a65ec07200.

* Revert "fix: backport bnb kernels (#297)"

This reverts commit 82955ba4403cc3d69400ae1e5588a42ac48c8ef5.

* Bump to sm_80

* Temporarily allow Turing

* Raise ValueError for older arches
AlpinDale committed 1 year ago · commit a98babfb74
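
A note on the guards this commit strips out of the host-side launch wrappers (an inference from the reverts, not something stated in the commit message): __CUDA_ARCH__ is only defined while nvcc compiles device code, so an "#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800" around the body of an ordinary host function is never true during host compilation and the body is compiled away. A minimal sketch with hypothetical scale_kernel / launch_scale names:

    #include <cuda_runtime.h>

    // Device code: __CUDA_ARCH__ is defined while nvcc compiles this for the
    // GPU, so an architecture guard here selects between code paths.
    __global__ void scale_kernel(float* data, float s, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n) return;
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
        data[i] *= s;  // an Ampere-specific path could live here
    #else
        data[i] *= s;  // fallback for older architectures
    #endif
    }

    // Host code: __CUDA_ARCH__ is never defined here. Wrapping this body in
    // "#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800" (as the guards
    // removed below in format.cu did) compiles the body away, so the kernel
    // would silently never launch.
    void launch_scale(float* data, float s, int n, cudaStream_t st)
    {
        scale_kernel<<<(n + 255) / 256, 256, 0, st>>>(data, s, n);
    }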

+ 5 - 0
aphrodite/common/config.py

@@ -194,6 +194,11 @@ class ModelConfig:
                     f"method specified in the `quantization` argument "
                     f"({self.quantization}).")
         if self.load_in_4bit:
+            # The kernels do not seem to work with 4-bit weight-only quantization on pre-Ampere GPUs.
+            if torch.cuda.get_device_capability(0)[0] < 8:
+                raise ValueError(
+                    "load_in_4bit quantization is not supported on GPUs with "
+                    "compute capability less than 8.0.")
             if self.quantization is None:
                 self.quantization = "bnb"
                 self.hf_config.quantization_config = {
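
The new check raises at config time on anything older than Ampere. For reference, a sketch of the same gate expressed against the CUDA runtime API; the helper name is hypothetical and not part of this patch:

    #include <cuda_runtime.h>
    #include <stdexcept>
    #include <string>

    // Refuse the 4-bit path on devices older than compute capability 8.0,
    // mirroring the Python-side ValueError above.
    static void require_sm80_or_newer(int device = 0)
    {
        int major = 0, minor = 0;
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
        cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
        if (major < 8) {
            throw std::runtime_error(
                "load_in_4bit requires compute capability >= 8.0, got " +
                std::to_string(major) + "." + std::to_string(minor));
        }
    }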

+ 1 - 1
aphrodite/modeling/layers/quantization/bitsandbytes.py

@@ -56,7 +56,7 @@ class BitsandBytesConfig(QuantizationConfig):
 
     def get_min_capability(self) -> int:
         # The BitsandBytes kernel only supports Ampere or newer GPUs.
-        return 80
+        return 75
 
     def merge_weight(self) -> bool:
         return True
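
The value returned here packs the compute capability as major * 10 + minor, so this hunk relaxes the floor from sm_80 (Ampere) to sm_75 (Turing), even though the comment above it still mentions Ampere. A small sketch of how such a packed value is typically compared against the running device (hypothetical helper, not from this patch):

    #include <cuda_runtime.h>

    // Compare a packed minimum capability (e.g. 75 for Turing, 80 for Ampere)
    // against the device actually in use.
    static bool device_meets(int min_capability, int device = 0)
    {
        cudaDeviceProp prop{};
        cudaGetDeviceProperties(&prop, device);
        return prop.major * 10 + prop.minor >= min_capability;
    }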

+ 0 - 30
kernels/quantization/bitsandbytes/format.cu

@@ -23,7 +23,6 @@ namespace autoquant {
 
 __device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     uint32_t old = *address;
     uint32_t assumed;
     do {
@@ -31,20 +30,16 @@ __device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t val
         uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));
         old          = atomicCAS(address, assumed, tmp);
     } while (assumed != old);
-#endif
 }
 
 __device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     return (*address >> (index * 4u)) & 0xfu;
-#endif
 }
 
 template<int... Ds>
 __global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     constexpr int N = sizeof...(Ds);
 
     size_t count = 1;
@@ -77,63 +72,50 @@ __global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)>
 
         atomic_assign_u4(dst + index / 8, index % 8, data);
     }
-#endif
 }
 
 void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     // permutation for [k/8, m] layout
     Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
     //        |warp|  lane  | 2x2 |  a0-7  |
     permute_u4<0, 3, 6, 8, 9, 1, 4, 7, 2, 5><<<512, 512, 0, st>>>(dst, src, shape);
-#endif
 }
 
 void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     // permutation for [k, m/8] layout
     Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
     //        |warp|  lane  | 2x2 |  a0-7  |
     permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
-#endif
 }
 
 __global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
     }
-#endif
 }
 
 __global__ void dequantize_s4_offset_64_bf16(uint4* dst, const uint32_t* src, size_t count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         dst[i] = dequantize_s4_to_bf16x2_v2(src[i]);
     }
-#endif
 }
 
 __global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         Q[i] = __halves2half2(zeros[i], scales[i]);
     }
-#endif
 }
 
 __global__ void merge_Q(__nv_bfloat162* Q, const __nv_bfloat16* scales, const __nv_bfloat16* zeros, int count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         Q[i] = __halves2bfloat162(zeros[i], scales[i]);
     }
-#endif
 }
 
 void convert_s4_k_m8(uint32_t*       A_dst,
@@ -147,11 +129,9 @@ void convert_s4_k_m8(uint32_t*       A_dst,
                      int             group_size,
                      cudaStream_t    st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
     merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
     reformat_s4_k_m8(A_dst, A_src, m, k, st);
-#endif
 }
 void convert_s4_k_m8(uint32_t*            A_dst,
                      __nv_bfloat162*      Q_dst,
@@ -164,48 +144,38 @@ void convert_s4_k_m8(uint32_t*            A_dst,
                      int                  group_size,
                      cudaStream_t         st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     dequantize_s4_offset_64_bf16<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
     merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
     reformat_s4_k_m8(A_dst, A_src, m, k, st);
-#endif
 }
 
 void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
     //      dequant   transpose    quant
     // 0123456 -> 0123564 -> 0135642 -> 0135264
     permute_u4<0, 1, 3, 5, 2, 6, 4><<<512, 512, 0, st>>>(dst, src, shape);
-#endif
 }
 
 // [2, k, m/8] -> [k, m/8, 2]
 void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
     //     dequant   transpose   quant
     // 012345 -> 012453 -> 124530 -> 124053
     permute_u4<1, 2, 4, 0, 5, 3><<<512, 512, 0, st>>>(dst, src, shape);
-#endif
 }
 
 __global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         dst[i] = dequantize_s4_to_fp16x2(src[i]);
     }
-#endif
 }
 
 void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     dequantize_s4_kernel<<<512, 512>>>(dst, src, count);
-#endif
 }
 
 }  // namespace autoquant

+ 0 - 10
kernels/quantization/bitsandbytes/gemm_s4_f16.cu

@@ -97,7 +97,6 @@ struct OutputOps {
 template<typename T_BC, typename T_Q>
 void Impl<T_BC, T_Q>::Generate(std::vector<Kernels>& kernels)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     // smem size (KB):
     // sm75: 64
     // sm80: 163
@@ -128,7 +127,6 @@ void Impl<T_BC, T_Q>::Generate(std::vector<Kernels>& kernels)
     k.emplace_back(new GemmKernel<Shape<64, 16, 256>, Shape<32, 16, 32>, 3, GS, Op, T_BC, T_Q>{});
     k.emplace_back(new GemmKernel<Shape<64, 8, 256>, Shape<32, 8, 32>, 3, GS, Op, T_BC, T_Q>{});
     kernels.push_back(std::move(k));
-#endif
 }
 
 template<typename T_BC, typename T_Q>
@@ -145,7 +143,6 @@ void Impl<T_BC, T_Q>::Measure(T_BC*                 C,
                               cudaStream_t          st,
                               std::vector<Kernels>& _kernels)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     int gid = -1;
     for (size_t i = 0; i < group_sizes_.size(); ++i) {
         if (group_sizes_[i] == group_size) {
@@ -198,7 +195,6 @@ void Impl<T_BC, T_Q>::Measure(T_BC*                 C,
         tmp.push_back(metrics[indices[i]]);
     }
     metrics.swap(tmp);
-#endif
 }
 
 static bool Compare(const Metric& a, const Metric& b)
@@ -249,7 +245,6 @@ void Impl<T_BC, T_Q>::Run(T_BC*                 C,
                           cudaStream_t          st,
                           std::vector<Kernels>& kernels)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (size_t i = 0; i < group_sizes_.size(); ++i) {
         if (group_sizes_[i] == group_size) {
             if (algo_id < 0) {
@@ -264,7 +259,6 @@ void Impl<T_BC, T_Q>::Run(T_BC*                 C,
         }
     }
     throw std::runtime_error("unsupported group size");
-#endif
 }
 
 template<typename T_BC, typename T_Q>
@@ -308,9 +302,7 @@ void GemmS4F16<T_BC, T_Q>::Measure(T_BC*                C,
                                    std::vector<Metric>& metrics,
                                    cudaStream_t         st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     impl_->Measure(C, A, B, Q, m, n, k, group_size, type, metrics, st, impl_->kernels_);
-#endif
 }
 
 template<typename T_BC, typename T_Q>
@@ -326,9 +318,7 @@ void GemmS4F16<T_BC, T_Q>::Run(T_BC*        C,
                                int          algo_id,
                                cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     impl_->Run(C, A, B, Q, m, n, k, group_size, type, algo_id, st, impl_->kernels_);
-#endif
 }
 
 template class GemmS4F16<half, half2>;
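
The Generate() hunk above sizes its kernel list around shared memory, and its in-code comment records the per-block budget (sm75: 64 KB, sm80: 163 KB). One way to make that budget explicit at runtime, for example to skip tile shapes that cannot fit on Turing, is to query the opt-in limit; this is a sketch of that idea, not something the patch does:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Query the opt-in per-block dynamic shared memory limit. Tile shapes
    // that need more than this (the larger sm_80 configurations) could be
    // skipped when generating the kernel list on Turing parts.
    static size_t max_dynamic_smem_bytes(int device = 0)
    {
        int bytes = 0;
        cudaDeviceGetAttribute(&bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
        return static_cast<size_t>(bytes);
    }

    int main()
    {
        std::printf("opt-in dynamic smem per block: %zu KB\n",
                    max_dynamic_smem_bytes() / 1024);
        return 0;
    }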

+ 0 - 4
kernels/quantization/bitsandbytes/int4_fp16_gemm_kernels.cu

@@ -23,7 +23,6 @@ void autoquant_convert_s4_k_m8(
   int m,
   int k,
   int group_size){
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
       auto st_ = _quant_scales.scalar_type();
       const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
       if(st_ == at::ScalarType::Half){
@@ -46,7 +45,6 @@ void autoquant_convert_s4_k_m8(
               aphrodite::autoquant::convert_s4_k_m8(weight_dest, quant_scales_zeros_dest, workspace, quant_weight_src, quant_scales, quant_zeros,
                             m, k, group_size, stream);
       }
-#endif
 }
 
 
@@ -55,7 +53,6 @@ torch::Tensor autoquant_s4_f16_gemm(
     torch::Tensor _kernel,
     torch::Tensor _scales_zeros)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     int num_in_feats = _in_feats.size(0);
     int num_in_channels = _in_feats.size(1);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
@@ -111,5 +108,4 @@ torch::Tensor autoquant_s4_f16_gemm(
                          stream);
         return _out_feats;   
     }
-#endif
 }
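
One more hazard of the removed guards in this file (an inference, not stated in the commit): autoquant_s4_f16_gemm returns a torch::Tensor, and with its entire body behind the #if, host compilation saw a non-void function with no return statement. A sketch of a binding-level guard that fails loudly instead; the wrapper name is hypothetical and the dispatch is elided:

    #include <torch/extension.h>
    #include <ATen/cuda/CUDAContext.h>

    // Reject unsupported devices with a Python-visible error and always
    // return a tensor, rather than compiling the body away.
    torch::Tensor guarded_s4_f16_gemm(torch::Tensor in_feats,
                                      torch::Tensor kernel,
                                      torch::Tensor scales_zeros)
    {
        const auto* props = at::cuda::getCurrentDeviceProperties();
        TORCH_CHECK(props->major * 10 + props->minor >= 75,
                    "bnb 4-bit GEMM needs compute capability >= 7.5, got ",
                    props->major, ".", props->minor);
        // ... dispatch to the real autoquant_s4_f16_gemm(...) here ...
        return torch::empty({0});  // placeholder output for the sketch
    }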