feat: backport kernels (#235)

* backport all kernels

* bump gguf capability to sm_61

* bump torch in CI
AlpinDale · 1 year ago · commit aebd68c632
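
The backport wraps every Ampere-only code path (cp.async, ldmatrix, the m16n8k16 mma, lop3) in a compile-time architecture guard, so the same sources still build when the target list is extended down to sm_61. A minimal sketch of the pattern, with an illustrative helper name that is not taken from the repo:

```cuda
#include <cstdint>

// Illustration of the guard pattern used throughout this commit: the
// sm_80-only inline PTX is compiled only when __CUDA_ARCH__ >= 800, so the
// helper still compiles (as a no-op) when building for sm_61/sm_70/sm_75.
__device__ inline void async_copy_16B(void* smem_dst, const void* gmem_src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(smem), "l"(gmem_src));
#else
  (void)smem_dst;  // no pre-Ampere fallback: the call compiles to nothing
  (void)gmem_src;
#endif
}
```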

+ 1 - 1
.github/workflows/publish.yml

@@ -49,7 +49,7 @@ jobs:
       matrix:
           os: ['ubuntu-20.04']
           python-version: ['3.8', '3.9', '3.10', '3.11']
-          pytorch-version: ['2.1.2']
+          pytorch-version: ['2.2.0']
           cuda-version: ['12.1']
 
     steps:

+ 1 - 1
aphrodite/modeling/layers/quantization/gguf.py

@@ -42,7 +42,7 @@ class GGUFConfig(QuantizationConfig):
         return [torch.half]
 
     def get_min_capability(self) -> int:
-        return 70
+        return 61
 
     @staticmethod
     def get_config_filenames() -> List[str]:
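
`get_min_capability` encodes the compute capability as major * 10 + minor, so lowering it from 70 to 61 admits Pascal consumer GPUs such as the GTX 10-series (sm_61). A standalone sketch of the device query that number corresponds to; this is the plain CUDA runtime API, not code from this repo:

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Capability "61" == major * 10 + minor, i.e. sm_61 (e.g. GTX 1080 / Tesla P40).
int main() {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device found\n");
    return 1;
  }
  int capability = prop.major * 10 + prop.minor;
  std::printf("capability=%d, gguf kernels supported=%s\n",
              capability, capability >= 61 ? "yes" : "no");
  return 0;
}
```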

+ 1 - 1
build-linux-wheel.sh

@@ -1,3 +1,3 @@
 #!/bin/bash
-export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
+export TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
 ./runtime.sh python setup.py bdist_wheel

+ 28 - 0
kernels/quantization/marlin/marlin_cuda_kernel.cu

@@ -54,6 +54,7 @@ using FragS = Vec<half2, 1>; // quantization scales
 // Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that
 // are not multiples of 16.
 __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
@@ -63,12 +64,14 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
     "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
     "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)
   );
+#endif
 }
 
 // Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for
 // quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
 // for inputs A and outputs C.
 __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const int BYTES = 16;
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
@@ -78,21 +81,27 @@ __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
     "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
     "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
   );
+#endif
 }
 
 // Async copy fence.
 __device__ inline void cp_async_fence() {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   asm volatile("cp.async.commit_group;\n" ::);
+#endif
 }
 
 // Wait until at most `n` async copy stages are still pending.
 template <int n>
 __device__ inline void cp_async_wait() {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
+#endif
 }
 
 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
 __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
   const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
   float* c = reinterpret_cast<float*>(&frag_c);
@@ -103,34 +112,40 @@ __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag
     :  "r"(a[0]),  "r"(a[1]),  "r"(a[2]),  "r"(a[3]),  "r"(b[0]),  "r"(b[1]),
        "f"(c[0]),  "f"(c[1]),  "f"(c[2]),  "f"(c[3])
   );
+#endif
 }
 
 // Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
 __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
     : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
   );
+#endif
 }
 
 // Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
 // automatically recognize it in all cases.
 template <int lut>
 __device__ inline int lop3(int a, int b, int c) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   int res;
   asm volatile(
     "lop3.b32 %0, %1, %2, %3, %4;\n"
     : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
   );
   return res;
+#endif
 }
 
 // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
 // We mostly follow the strategy in the link below, with some small changes:
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
 __device__ inline FragB dequant(int q) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const int LO = 0x000f000f;
   const int HI = 0x00f000f0;
   const int EX = 0x64006400;
@@ -151,17 +166,21 @@ __device__ inline FragB dequant(int q) {
     *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)
   );
   return frag_b;
+#endif
 }
 
 // Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
 __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
   frag_b[0] = __hmul2(frag_b[0], s);
   frag_b[1] = __hmul2(frag_b[1], s);
+#endif
 }
 
 // Wait until barrier reaches `count`, then lock for current threadblock.
 __device__ inline void barrier_acquire(int* lock, int count) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   if (threadIdx.x == 0) {
     int state = -1;
     do
@@ -170,10 +189,12 @@ __device__ inline void barrier_acquire(int* lock, int count) {
     while (state != count);
   }
   __syncthreads();
+#endif
 }
 
 // Release barrier and increment visitation count.
 __device__ inline void barrier_release(int* lock, bool reset = false) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   __syncthreads();
   if (threadIdx.x == 0) {
     if (reset) {
@@ -185,6 +206,7 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
     asm volatile ("fence.acq_rel.gpu;\n");
     asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
   }
+#endif
 }
 
 template <
@@ -205,6 +227,7 @@ __global__ void Marlin(
   int  prob_k, // reduction dimension k
   int* locks // extra global storage for barrier synchronization
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
   // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
   //   0 1 3
@@ -670,6 +693,7 @@ __global__ void Marlin(
       }
     }
   }
+#endif
 }
 
 
@@ -715,6 +739,7 @@ int marlin_cuda(
   int thread_n = -1,
   int sms = -1
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   int tot_m = prob_m;
   int tot_m_blocks = ceildiv(tot_m, 16);
 
@@ -779,6 +804,7 @@ int marlin_cuda(
   }
 
   return ret;
+#endif
 }
 
 #endif
@@ -802,6 +828,7 @@ void marlin_gemm(
   const torch::Tensor& scales,
         torch::Tensor& workspace
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
   int thread_k = -1;
   // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
@@ -840,4 +867,5 @@ void marlin_gemm(
       "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
     );
   }
+#endif
 }
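
The helpers guarded above wrap cp.async, ldmatrix, and the m16n8k16 tensor-core mma, all of which require sm_80 or newer; the #if blocks only make the file compile for the widened arch list and do not add a pre-Ampere execution path. A dispatcher would therefore still need a runtime capability check before taking the Marlin path; a minimal sketch under that assumption, with `launch_marlin_kernel` as a placeholder rather than a function from this repo:

```cuda
#include <cuda_runtime.h>
#include <stdexcept>

// Hypothetical host-side gate: Marlin still needs an Ampere-or-newer device
// at runtime even though the file now builds for sm_61 and up.
void launch_marlin_or_throw(int device /*, tensor arguments ... */) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  if (prop.major * 10 + prop.minor < 80) {
    throw std::runtime_error("Marlin kernels require compute capability >= 8.0");
  }
  // launch_marlin_kernel(/* ... */);  // placeholder for the real launch
}
```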

+ 35 - 1
kernels/quantization/quip/origin_order.cu

@@ -1,7 +1,9 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
-#include <mma.h>
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 700
+  #include <mma.h>
+#endif
 
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
@@ -103,6 +105,7 @@ static __device__ void store(
     int32_t nTile,
     int32_t laneId,
     const float4& out) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
 
   // sum.x / sum.y are written at
   // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
@@ -126,6 +129,7 @@ static __device__ void store(
   if (outRow + 8 < m) {
     *reinterpret_cast<half2*>(cPtr + 8 * n) = v23;
   }
+#endif
 }
 };
 
@@ -144,6 +148,7 @@ static __device__ void load(
     int32_t kTileStart,
     int32_t laneId,
     f16x2x2_u32 b[KTilesPerIteration]) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   auto Bptr = reinterpret_cast<const uint8_t*>(B);
   #pragma unroll
   for (int i = 0; i < KTilesPerIteration; ++i) {
@@ -151,6 +156,7 @@ static __device__ void load(
        const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
        *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[row * k/4 + col]];
   }
+#endif
 }
 };
 
@@ -169,6 +175,7 @@ static __device__ void load(
     int32_t kTileStart,
     int32_t laneId,
     f16x2x2_u32 b[KTilesPerIteration]) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   auto Bptr = reinterpret_cast<const uint32_t*>(B);
   #pragma unroll
   for (int i = 0; i < KTilesPerIteration; ++i) {
@@ -189,6 +196,7 @@ static __device__ void load(
       *(half2*)(b[i].vals) = __hfma2(*((half2*)(&q0)), y16, z16);
       *(half2*)(b[i].vals+1) = __hfma2(*((half2*)(&q1)), y16, z16);
   }
+#endif
 }
 };
 
@@ -199,6 +207,7 @@ __device__ static inline uint64_t decode8weights(
     uint16_t weight_compressed,
     const int64_t *__restrict__ codebook_abs
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
 
     uint8_t bits_sign = weight_compressed & 0xff;
     uint8_t parity = __popc(bits_sign) & 1;
@@ -215,6 +224,7 @@ __device__ static inline uint64_t decode8weights(
     packed -= parity * 0x0202020202020202;
 
     return packed;
+#endif
 }
 
 __device__ static inline uint32_t decode8weights(
@@ -222,6 +232,7 @@ __device__ static inline uint32_t decode8weights(
     const int64_t *__restrict__ codebook_abs,
     int idx
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
     uint8_t bits_sign = weight_compressed & 0xff; //__brev(weight_compressed) >> 24;
     const uint32_t magic_nums[2] = {0x08040201ll, 0x80402010ll};
     uint8_t parity = __popc(bits_sign) & 1;
@@ -237,6 +248,7 @@ __device__ static inline uint32_t decode8weights(
     packed |= 0x01010101;
     packed -= parity * 0x02020202;
     return packed;
+#endif
 };
 
 template <int KTilesPerIteration>
@@ -251,6 +263,7 @@ static __device__ void load(
     int32_t kTileStart,
     int32_t laneId,
     f16x2x2_u32 b[KTilesPerIteration]) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   auto Bptr = (const uint16_t*) B;
   #pragma unroll
   for (int i = 0; i < KTilesPerIteration; ++i) {
@@ -275,6 +288,7 @@ static __device__ void load(
        //*((half*)(b[i].vals) + 2) = unpacked[1].x;
        //*((half*)(b[i].vals) + 3) = unpacked[1].y;
   }
+#endif
 }
 };
 
@@ -304,6 +318,7 @@ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
     int32_t mTiles,
     int32_t nTiles,
     int32_t kTiles) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   __shared__ uint64_t CB_[256];
   if (BLayout::use_codebook) {
     CB_[threadIdx.x + threadIdx.y * 32] = CB[threadIdx.x + threadIdx.y * 32];
@@ -444,12 +459,14 @@ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
         laneId,
         sum_f32);
   }
+#endif
 }
 
 at::Tensor d4_mm_origorder(
     const at::Tensor& A,
     const at::Tensor& B,
     const at::Tensor& CB) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   c10::cuda::CUDAGuard g(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();
 
@@ -488,12 +505,14 @@ at::Tensor d4_mm_origorder(
       kTiles);
 
   return C_final;
+#endif
 }
 
 at::Tensor e8p_mm_origorder(
     const at::Tensor& A,
     const at::Tensor& B,
     const at::Tensor& CB) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   c10::cuda::CUDAGuard g(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();
 
@@ -531,11 +550,13 @@ at::Tensor e8p_mm_origorder(
       kTiles);
 
   return C_final;
+#endif
 }
 
 at::Tensor hi_mm_origorder(
     const at::Tensor& A,
     const at::Tensor& B) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   c10::cuda::CUDAGuard g(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();
 
@@ -573,6 +594,7 @@ at::Tensor hi_mm_origorder(
       kTiles);
 
   return C_final;
+#endif
 }
 
 #define DECOMPRESS_D4_BLOCK_SIZE 256
@@ -582,12 +604,14 @@ __global__ void cuda_decompress_d4_origorder_kernel(
     const c10::Half* __restrict__ CB,           // 256 x 4
     c10::Half* __restrict__ Y             // m x n
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x;
 
   for(long r = 0; r < 4; r++) {
     uint8_t yidx = ((uint8_t*)YIs)[i*4 + r];
     ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255];
   }
+#endif
 }
 
 
@@ -596,6 +620,7 @@ void decompress_d4_origorder(
     torch::Tensor CB,       // 256 x 4
     torch::Tensor Y         // m x n
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   size_t m = Y.sizes()[0];
   size_t n = Y.sizes()[1];
 
@@ -615,6 +640,7 @@ void decompress_d4_origorder(
     CB.data_ptr<c10::Half>(),
     Y.data_ptr<c10::Half>()
   );
+#endif
 }
 
 #define DECOMPRESS_E8P_BLOCK_SIZE 256
@@ -624,6 +650,7 @@ __global__ void cuda_decompress_e8p_origorder_kernel(
     const int64_t* __restrict__ CB, // 256 x 8
     c10::Half* __restrict__ Y             // m x n
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const long i = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x;
   uint16_t yidx = ((uint16_t*)YIs)[i];
   uint64_t decoded =  BLayout_E8::decode8weights(yidx, CB);
@@ -643,6 +670,7 @@ __global__ void cuda_decompress_e8p_origorder_kernel(
   ((__half2*)Y)[i*4+2] = __hadd2(unpacked[0][1], adjust); // 45
   ((__half2*)Y)[i*4+1] = __hadd2(unpacked[1][0], adjust); // 23
   ((__half2*)Y)[i*4+3] = __hadd2(unpacked[1][1], adjust); // 67
+#endif
 }
 
 
@@ -651,6 +679,7 @@ void decompress_e8p_origorder(
     torch::Tensor CB,       // 256 x 8
     torch::Tensor &Y         // m x n
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   size_t m = Y.sizes()[0];
   size_t n = Y.sizes()[1];
 
@@ -670,6 +699,7 @@ void decompress_e8p_origorder(
     CB.data_ptr<int64_t>(),
     Y.data_ptr<c10::Half>()
   );
+#endif
 }
 
 #define DECOMPRESS_HI_BLOCK_SIZE 256
@@ -678,6 +708,7 @@ __global__ void cuda_decompress_hi_origorder_kernel(
     const uint32_t* __restrict__ YIs,	  // m x (n/8)
     c10::Half* __restrict__ Y             // m x n
 ) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const long i = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x;
   uint32_t qa = YIs[i];
 
@@ -697,12 +728,14 @@ __global__ void cuda_decompress_hi_origorder_kernel(
   ((__half2*)Y)[i*4+1] = __hfma2(*((half2*)(&q1)), y16, z16);
   ((__half2*)Y)[i*4+2] = __hfma2(*((half2*)(&q2)), y16, z16);
   ((__half2*)Y)[i*4+3] = __hfma2(*((half2*)(&q3)), y16, z16);
+#endif
 }
 
 void decompress_hi_origorder(
     torch::Tensor YIs,      // m x (n/8)
     torch::Tensor Y         // m x n
 ){
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   size_t m = Y.sizes()[0];
   size_t n = Y.sizes()[1];
 
@@ -719,4 +752,5 @@ void decompress_hi_origorder(
     (uint32_t*)YIs.data_ptr<int32_t>(),
     Y.data_ptr<c10::Half>()
   );
+#endif
 }

+ 1 - 1
setup.py

@@ -20,7 +20,7 @@ MAIN_CUDA_VERSION = "12.1"
 
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {
-    "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"
+    "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"
 }
 ROCM_SUPPORTED_ARCHS = {
     "gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"