#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// __device__ double atomicAdd(
//     double* address,
//     double val
// ) {
//   unsigned long long int* address_as_ull = (unsigned long long int*)address;
//   unsigned long long int old = *address_as_ull, assumed;
//
//   do {
//     assumed = old;
//     old = atomicCAS(
//       address_as_ull,
//       assumed,
//       __double_as_longlong(val + __longlong_as_double(assumed))
//     );
//
//   // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
//   } while (assumed != old);
//
//   return __longlong_as_double(old);
// }
// #endif

#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
// Emulates a 16-bit atomicAdd with a 32-bit compare-and-swap on the aligned
// word that contains the half value.
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(
      reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum += val;
    old = reinterpret_cast<size_t>(address) & 2
              ? (old & 0xffff) | (hsum << 16)
              : (old & 0xffff0000) | hsum;
    old = atomicCAS(address_as_ui, assumed, old);

  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}

__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
  unsigned int *address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    __half_raw hsum;
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    half tmpres = __hadd(hsum, val);
    hsum = __half_raw(tmpres);
    old = (size_t)address & 2
              ? (old & 0xffff) | (hsum.x << 16)
              : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}
#endif
// Kernels using per-row group indices (g_idx).
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
);

// Kernels using a fixed groupsize instead of g_idx.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel_old(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel_old(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel_old(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel_old(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

// Half2-input variants of the fixed-groupsize kernels.
__global__ void VecQuant2MatMulKernelFaster_old(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

__global__ void VecQuant3MatMulKernelFaster_old(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

__global__ void VecQuant4MatMulKernelFaster_old(
    const half2* __restrict__ vec, const int* __restrict__ mat,
    float* __restrict__ mul, const float* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch, int vec_height, int height, int width, int zero_width,
    int groupsize
);

const int BLOCKWIDTH   = 64;
const int BLOCKHEIGHT2 =  4;
const int BLOCKHEIGHT3 =  6;
const int BLOCKHEIGHT4 =  8;
const int BLOCKHEIGHT8 = 16;
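// Added sanity check (not in the original source): each BLOCKHEIGHTn constant
// is the number of packed 32-bit rows of `mat` that cover BLOCKWIDTH unpacked
// rows at n bits per weight, i.e. BLOCKHEIGHTn = BLOCKWIDTH * n / 32.
static_assert(BLOCKHEIGHT2 == BLOCKWIDTH * 2 / 32, "2-bit packing ratio");
static_assert(BLOCKHEIGHT3 == BLOCKWIDTH * 3 / 32, "3-bit packing ratio");
static_assert(BLOCKHEIGHT4 == BLOCKWIDTH * 4 / 32, "4-bit packing ratio");
static_assert(BLOCKHEIGHT8 == BLOCKWIDTH * 8 / 32, "8-bit packing ratio");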
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

__device__ inline int as_int(int i) {
  return *reinterpret_cast<int*>(&i);
}

void vecquant2matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}

// Multiplies a batch of dense vectors by a 2-bit quantized matrix: each 32-bit
// word of `mat` packs 16 weights along the height dimension, each word of
// `zeros` packs 16 zero-points along the width dimension, and g_idx maps each
// unpacked row to its quantization group.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
) {
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 16;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  float weight[BLOCKWIDTH];
  for (k = 0; k < BLOCKWIDTH; ++k) {
    int k_w = (k / 16);
    int k_bit = (k % 16) * 2;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);

    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}
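// Added reference sketch (not part of the original kernels): recovers a single
// 2-bit weight, mirroring the arithmetic of VecQuant2MatMulKernel above. The
// field index for the weight comes from the row within a 16-row group (k % 16),
// while the field index for the zero-point comes from the output column
// (w % 16); zero-points are stored minus one, hence the "+ 1". Function and
// parameter names are illustrative only.
__host__ __device__ inline float dequant2bit_reference(
    unsigned int packed_w,  // word of `mat` holding the weight
    unsigned int packed_z,  // word of `zeros` holding the zero-point
    int w_field,            // 2-bit slot of the weight within packed_w (0..15)
    int z_field,            // 2-bit slot of the zero within packed_z (0..15)
    float scale             // per-group, per-column scale
) {
  unsigned int q    = (packed_w >> (w_field * 2)) & 0x3;
  unsigned int zero = ((packed_z >> (z_field * 2)) & 0x3) + 1;
  return scale * (float(q) - float(zero));
}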
void vecquant3matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}

// 3-bit variant: 32 values are packed into three 32-bit words, so slots 10 and
// 21 straddle a word boundary and are handled as special cases below.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
) {
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  // Locate the word (z_w) and bit offset (z_bit) of this column's zero-point.
  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;
  unsigned int z_tmp;
  if (z_mod != 10) {
    if (z_mod != 21) {
      z_bit = z_mod;
      if (z_bit > 21) {
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10) {
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  float weight[BLOCKWIDTH];
  for (k = 0; k < BLOCKWIDTH; ++k) {
    // Locate the word (k_w) and bit offset (k_bit) of the k-th weight.
    int k_w = (k / 32) * 3;
    int k_mod = k % 32;
    int k_bit;
    if (k_mod != 10) {
      if (k_mod != 21) {
        k_bit = k_mod;
        if (k_bit > 21) {
          k_bit -= 22;
          k_bit *= 3;
          k_bit += 2;
          k_w += 2;
        } else if (k_bit > 10) {
          k_bit -= 11;
          k_bit *= 3;
          k_bit += 1;
          k_w += 1;
        } else {
          k_bit *= 3;
        }
      } else {
        k_w += 1;
      }
    }

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30)
            | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scalar_t((z_tmp) + 1);
    } else if (z_mod == 21) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31)
            | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scalar_t((z_tmp) + 1);
    } else {
      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    if (k_mod == 10) {
      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30)
            | ((as_unsigned(mat[i + ((k_w + 1) * width)]) << 2) & 0x4);
    } else if (k_mod == 21) {
      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31)
            | ((as_unsigned(mat[i + ((k_w + 1) * width)]) << 1) & 0x6);
    } else {
      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7);
    }
    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}

void vecquant4matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}

// 4-bit variant: eight weights (and eight zero-points) per 32-bit word.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
) {
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 8;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  float weight[BLOCKWIDTH];
  for (k = 0; k < BLOCKWIDTH; ++k) {
    int k_w = (k / 8);
    int k_bit = (k % 8) * 4;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);

    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}
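// Added reference sketch (not part of the original kernels): extracts the k-th
// 3-bit value (k in [0, 32)) from the three 32-bit words that pack a group of
// 32 values, reproducing the k_w / k_bit indexing of VecQuant3MatMulKernel
// above. Slots 10 and 21 straddle a word boundary, which is why the kernel
// special-cases them. The function name is illustrative only.
__host__ __device__ inline unsigned int unpack3bit_reference(const unsigned int w[3], int k) {
  if (k == 10)  // bits 30..32: low two bits in w[0], high bit in w[1]
    return (w[0] >> 30) | ((w[1] << 2) & 0x4);
  if (k == 21)  // bits 63..65: low bit in w[1], high two bits in w[2]
    return (w[1] >> 31) | ((w[2] << 1) & 0x6);
  if (k > 21)
    return (w[2] >> ((k - 22) * 3 + 2)) & 0x7;
  if (k > 10)
    return (w[1] >> ((k - 11) * 3 + 1)) & 0x7;
  return (w[0] >> (k * 3)) & 0x7;
}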
void vecquant8matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}

// 8-bit variant: four weights (and four zero-points) per 32-bit word.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec, const int* __restrict__ mat,
    scalar_t* __restrict__ mul, const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros, const int* __restrict__ g_idx,
    int batch, int vec_height, int height, int width, int zero_width
) {
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 4;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  float weight[BLOCKWIDTH];
  for (k = 0; k < BLOCKWIDTH; ++k) {
    int k_w = (k / 4);
    int k_bit = (k % 4) * 8;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);

    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}
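// The launchers above (vecquant2matmul_cuda ... vecquant8matmul_cuda) are plain
// C++ entry points; exposing them to Python is normally done from a separate
// binding translation unit. The commented sketch below illustrates such a
// binding, assuming the usual torch C++ extension layout; the file, exported
// names, and docstring are illustrative and not part of this source.
//
//   #include <torch/extension.h>
//
//   void vecquant4matmul_cuda(
//     torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
//     torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("vecquant4matmul", &vecquant4matmul_cuda,
//           "vector by 4-bit quantized matrix multiplication (CUDA)");
//   }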