Bläddra i källkod

fix: bump bnb kernels to sm_80 due to async stream copies

AlpinDale 1 år sedan
förälder
incheckning
49793d7c5a

+ 2 - 2
aphrodite/modeling/layers/quantization/bitsandbytes.py

@@ -55,8 +55,8 @@ class BitsandBytesConfig(QuantizationConfig):
         return [torch.half, torch.bfloat16]
 
     def get_min_capability(self) -> int:
-        # The BitsandBytes kernel only supports Turing or newer GPUs.
-        return 75
+        # The BitsandBytes kernel only supports Ampere or newer GPUs.
+        return 80
 
     def merge_weight(self) -> bool:
         return True

+ 15 - 15
kernels/quantization/bitsandbytes/format.cu

@@ -23,7 +23,7 @@ namespace autoquant {
 
 __device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     uint32_t old = *address;
     uint32_t assumed;
     do {
@@ -36,7 +36,7 @@ __device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t val
 
 __device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     return (*address >> (index * 4u)) & 0xfu;
 #endif
 }
@@ -44,7 +44,7 @@ __device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
 template<int... Ds>
 __global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     constexpr int N = sizeof...(Ds);
 
     size_t count = 1;
@@ -82,7 +82,7 @@ __global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)>
 
 void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     // permutation for [k/8, m] layout
     Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
     //        |warp|  lane  | 2x2 |  a0-7  |
@@ -92,7 +92,7 @@ void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
 
 void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     // permutation for [k, m/8] layout
     Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
     //        |warp|  lane  | 2x2 |  a0-7  |
@@ -102,7 +102,7 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
 
 __global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
     }
@@ -111,7 +111,7 @@ __global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t
 
 __global__ void dequantize_s4_offset_64_bf16(uint4* dst, const uint32_t* src, size_t count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         dst[i] = dequantize_s4_to_bf16x2_v2(src[i]);
     }
@@ -120,7 +120,7 @@ __global__ void dequantize_s4_offset_64_bf16(uint4* dst, const uint32_t* src, si
 
 __global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         Q[i] = __halves2half2(zeros[i], scales[i]);
     }
@@ -129,7 +129,7 @@ __global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int cou
 
 __global__ void merge_Q(__nv_bfloat162* Q, const __nv_bfloat16* scales, const __nv_bfloat16* zeros, int count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         Q[i] = 	halves2bfloat162(zeros[i], scales[i]);
     }
@@ -147,7 +147,7 @@ void convert_s4_k_m8(uint32_t*       A_dst,
                      int             group_size,
                      cudaStream_t    st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
     merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
     reformat_s4_k_m8(A_dst, A_src, m, k, st);
@@ -164,7 +164,7 @@ void convert_s4_k_m8(uint32_t*            A_dst,
                      int                  group_size,
                      cudaStream_t         st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     dequantize_s4_offset_64_bf16<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
     merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
     reformat_s4_k_m8(A_dst, A_src, m, k, st);
@@ -173,7 +173,7 @@ void convert_s4_k_m8(uint32_t*            A_dst,
 
 void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
     //      dequant   transpose    quant
     // 0123456 -> 0123564 -> 0135642 -> 0135264
@@ -184,7 +184,7 @@ void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, i
 // [2, k, m/8] -> [k, m/8, 2]
 void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
     //     dequant   transpose   quant
     // 012345 -> 012453 -> 124530 -> 124053
@@ -194,7 +194,7 @@ void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaSt
 
 __global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
         dst[i] = dequantize_s4_to_fp16x2(src[i]);
     }
@@ -203,7 +203,7 @@ __global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
 
 void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     dequantize_s4_kernel<<<512, 512>>>(dst, src, count);
 #endif
 }

+ 5 - 5
kernels/quantization/bitsandbytes/gemm_s4_f16.cu

@@ -97,7 +97,7 @@ struct OutputOps {
 template<typename T_BC, typename T_Q>
 void Impl<T_BC, T_Q>::Generate(std::vector<Kernels>& kernels)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     // smem size (KB):
     // sm75: 64
     // sm80: 163
@@ -145,7 +145,7 @@ void Impl<T_BC, T_Q>::Measure(T_BC*                 C,
                               cudaStream_t          st,
                               std::vector<Kernels>& _kernels)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     int gid = -1;
     for (size_t i = 0; i < group_sizes_.size(); ++i) {
         if (group_sizes_[i] == group_size) {
@@ -249,7 +249,7 @@ void Impl<T_BC, T_Q>::Run(T_BC*                 C,
                           cudaStream_t          st,
                           std::vector<Kernels>& kernels)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     for (size_t i = 0; i < group_sizes_.size(); ++i) {
         if (group_sizes_[i] == group_size) {
             if (algo_id < 0) {
@@ -308,7 +308,7 @@ void GemmS4F16<T_BC, T_Q>::Measure(T_BC*                C,
                                    std::vector<Metric>& metrics,
                                    cudaStream_t         st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     impl_->Measure(C, A, B, Q, m, n, k, group_size, type, metrics, st, impl_->kernels_);
 #endif
 }
@@ -326,7 +326,7 @@ void GemmS4F16<T_BC, T_Q>::Run(T_BC*        C,
                                int          algo_id,
                                cudaStream_t st)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     impl_->Run(C, A, B, Q, m, n, k, group_size, type, algo_id, st, impl_->kernels_);
 #endif
 }

+ 2 - 2
kernels/quantization/bitsandbytes/int4_fp16_gemm_kernels.cu

@@ -23,7 +23,7 @@ void autoquant_convert_s4_k_m8(
   int m,
   int k,
   int group_size){
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
       auto st_ = _quant_scales.scalar_type();
       const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
       if(st_ == at::ScalarType::Half){
@@ -55,7 +55,7 @@ torch::Tensor autoquant_s4_f16_gemm(
     torch::Tensor _kernel,
     torch::Tensor _scales_zeros)
 {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     int num_in_feats = _in_feats.size(0);
     int num_in_channels = _in_feats.size(1);
     const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));