
feat: support GPTQ 2, 3, and 8bit quants (#181)

* add kernels

* add in setup.py

* pack_factor -> storage bits size

* modify gptq to support different bits

* add triton kernels

* formatting

* pylint needs to go

* replace autogptq kernels with custom exl2 kernels

* fix compile errors

* fix formatting
AlpinDale 1 year ago
commit 801eda0b7a

+ 2 - 2
aphrodite/modeling/layers/linear.py

@@ -275,8 +275,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
-                # If quantized, we need to adjust the offset and size to account
-                # for the packing.
+                # If quantized, we need to adjust the offset and size to
+                # account for the packing.
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
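The packed shard arithmetic above relies on `pack_factor` now being a `fractions.Fraction` (see the gptq.py change below). A minimal sketch of the idea, using illustrative 3-bit numbers rather than values from a real checkpoint:

    from fractions import Fraction

    # 3-bit GPTQ packs weights into 32-bit words, so the pack factor is the
    # non-integral ratio 32/3; Fraction keeps the shard math exact.
    pack_factor = Fraction(32, 3)

    shard_size, shard_offset = 1024, 2048   # illustrative sizes only

    # Floor division by a Fraction returns a plain int, so the packed shard
    # bounds stay integral whenever the unpacked sizes are 32-aligned.
    packed_size = shard_size // pack_factor      # 1024 * 3 // 32 == 96
    packed_offset = shard_offset // pack_factor  # 2048 * 3 // 32 == 192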

+ 16 - 10
aphrodite/modeling/layers/quantization/gptq.py

@@ -1,6 +1,7 @@
 import enum
 from enum import Enum
 from typing import Any, Dict, List, Optional
+from fractions import Fraction

 import torch
 from torch.nn.parameter import Parameter
@@ -14,6 +15,7 @@ from aphrodite.modeling.layers.quantization.base_config import (
 
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
+
     Reference: https://arxiv.org/abs/2210.17323
     """
 
@@ -26,12 +28,11 @@ class GPTQConfig(QuantizationConfig):
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
-        self.pack_factor = 32 // self.weight_bits
-        # exllama kernel v1 only supports 4 bit
-        if self.weight_bits != 4:
+        self.pack_factor = Fraction(32, self.weight_bits)
+        if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
-                "Currently, only 4-bit weight quantization is supported for "
-                f"GPTQ, but got {self.weight_bits} bits.")
+                "Currently, only 2/3/4/8-bit weight quantization is supported "
+                f"for GPTQ, but got {self.weight_bits} bits.")
 
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
@@ -78,6 +79,7 @@ class ExllamaState(Enum):
 
 class GPTQLinearMethod(LinearMethodBase):
     """Linear method for GPTQ.
+
     Args:
         quant_config: The GPTQ quantization config.
     """
@@ -99,11 +101,13 @@ class GPTQLinearMethod(LinearMethodBase):
                 "The input size is not aligned with the quantized "
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
                 "tensor parallel size.")
-        if output_size_per_partition % self.quant_config.pack_factor != 0:
+        if (output_size_per_partition % self.quant_config.pack_factor.numerator
+                != 0):
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
+
         if self.quant_config.group_size != -1:
             group_size = self.quant_config.group_size
         else:
@@ -194,8 +198,8 @@ class GPTQLinearMethod(LinearMethodBase):
         qweight = weights["qweight"]
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
-        # exllama needs to shuffle the weight after it's loaded
-        # here we do the shuffle on the first forward pass
+        # exllama needs to shuffle the weight after the weight is loaded
+        # here we do the shuffle on first forward pass
         if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
         if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
             if self.quant_config.desc_act:
                 weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
                 weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
@@ -203,11 +207,13 @@ class GPTQLinearMethod(LinearMethodBase):
             else:
                 weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            quantization_ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
+            quantization_ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                                          self.quant_config.weight_bits)
         output = quantization_ops.gptq_gemm(
             reshaped_x, weights["qweight"], weights["qzeros"],
             weights["scales"], weights["g_idx"],
-            weights["exllama_state"] == ExllamaState.READY)
+            weights["exllama_state"] == ExllamaState.READY,
+            self.quant_config.weight_bits)
         if bias is not None:
             output = output + bias
         return output.reshape(out_shape)
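Since `pack_factor` is a reduced `Fraction`, the alignment check against `pack_factor.numerator` behaves differently per bit width. A quick sketch of the values it reduces to (not part of the diff):

    from fractions import Fraction

    # Values of Fraction(32, weight_bits) after normalisation; the output-size
    # check in create_weights divides by the numerator of each.
    for bits in (2, 3, 4, 8):
        pf = Fraction(32, bits)
        print(bits, pf, pf.numerator)
    # 2 -> 16     (numerator 16)
    # 3 -> 32/3   (numerator 32)
    # 4 -> 8      (numerator 8)
    # 8 -> 4      (numerator 4)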

+ 4 - 2
kernels/ops.h

@@ -84,9 +84,11 @@ torch::Tensor gptq_gemm(
   torch::Tensor b_gptq_qzeros,
   torch::Tensor b_gptq_scales,
   torch::Tensor b_g_idx,
-  bool use_exllama);
+  bool use_exllama,
+  int bit);
 
 void gptq_shuffle(
   torch::Tensor q_weight,
-  torch::Tensor q_perm);
+  torch::Tensor q_perm,
+  int bit);
   

+ 72 - 0
kernels/quantization/gptq/autogptq_cuda_256.cpp

@@ -0,0 +1,72 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant2matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant3matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant4matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant4matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant8matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+
+}

+ 71 - 0
kernels/quantization/gptq/autogptq_cuda_64.cpp

@@ -0,0 +1,71 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant2matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant3matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant4matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant4matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant8matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)");
+}

+ 654 - 0
kernels/quantization/gptq/autogptq_cuda_kernel_256.cu

@@ -0,0 +1,654 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// atomicAdd for double-precision floating-point numbers on hardware with
+// compute capability < 6.0 from:
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+// __device__ double atomicAdd(
+//     double* address,
+//     double val
+// ) {
+//   unsigned long long int* address_as_ull = (unsigned long long int*)address;
+//   unsigned long long int old = *address_as_ull, assumed;
+//
+//   do {
+//     assumed = old;
+//     old = atomicCAS(
+//       address_as_ull,
+//       assumed,
+//       __double_as_longlong(val + __longlong_as_double(assumed))
+//     );
+//
+//   // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+//   } while (assumed != old);
+//
+//   return __longlong_as_double(old);
+// }
+// #endif
+
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
+// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
+
+__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
+        hsum += val;
+        old = reinterpret_cast<size_t>(address) & 2
+                 ? (old & 0xffff) | (hsum << 16)
+                 : (old & 0xffff0000) | hsum;
+        old = atomicCAS(address_as_ui, assumed, old);
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+}
+__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
+    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        __half_raw hsum;
+        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+        half tmpres = __hadd(hsum, val);
+        hsum = __half_raw(tmpres);
+        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+        old = atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+}
+#endif
+
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  	int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  	int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  	int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+__global__ void VecQuant2MatMulKernelFaster_old(
+    const  half2* __restrict__ vec,
+    const    int* __restrict__ mat,
+           float* __restrict__ mul,
+    const  float* __restrict__ scales,
+    const    int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+__global__ void VecQuant3MatMulKernelFaster_old(
+    const  half2* __restrict__ vec,
+    const    int* __restrict__ mat,
+           float* __restrict__ mul,
+    const  float* __restrict__ scales,
+    const    int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+__global__ void VecQuant4MatMulKernelFaster_old(
+    const  half2* __restrict__ vec,
+    const    int* __restrict__ mat,
+           float* __restrict__ mul,
+    const  float* __restrict__ scales,
+    const    int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+
+const int BLOCKWIDTH  = 256;
+const int BLOCKHEIGHT2 =  16;
+const int BLOCKHEIGHT3 =  24;
+const int BLOCKHEIGHT4 =  32;
+const int BLOCKHEIGHT8 =  64;
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+__device__ inline int as_int(int i) {
+  return *reinterpret_cast<int*>(&i);
+}
+
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant2matmul_cuda", ([&] {
+      VecQuant2MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT2 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 16;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+
+  int z_w = w / 16;
+  int z_mod = (w % 16) * 2;
+
+  float weight[BLOCKWIDTH];
+
+  for (k = 0; k <  BLOCKWIDTH; ++k){
+	int k_w = (k / 16);
+	int k_bit = (k % 16) * 2;
+
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
+
+    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
+
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){
+	res = 0;
+
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant3matmul_cuda", ([&] {
+      VecQuant3MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT3 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = (h / 3) * 32;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+
+  int z_w = (w / 32) * 3;
+  int z_mod = w % 32;
+  int z_bit;
+  unsigned int z_tmp;
+  if (z_mod != 10){
+    if (z_mod != 21){
+      z_bit = z_mod;
+      if (z_bit > 21){
+        z_bit -= 22;
+        z_bit *= 3;
+        z_bit += 2;
+        z_w += 2;
+      } else if (z_bit > 10){
+        z_bit -= 11;
+        z_bit *= 3;
+        z_bit += 1;
+        z_w += 1;
+      } else {
+        z_bit *= 3;
+      }
+    } else {
+      z_w += 1;
+    }
+  }
+
+  float weight[BLOCKWIDTH];
+
+  for (k = 0; k <  BLOCKWIDTH; ++k){
+	int k_w = (k / 32) * 3;
+	int k_mod = k % 32;
+	int k_bit;
+
+	if (k_mod != 10){
+	  if (k_mod != 21){
+        k_bit = k_mod;
+        if (k_bit > 21){
+		  k_bit -= 22;
+		  k_bit *= 3;
+		  k_bit += 2;
+		  k_w += 2;
+        } else if (k_bit > 10){
+		  k_bit -= 11;
+		  k_bit *= 3;
+		  k_bit += 1;
+		  k_w += 1;
+        } else {
+		  k_bit *= 3;
+        }
+	  } else {
+        k_w += 1;
+	  }
+	}
+
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero;
+    if (z_mod == 10) {
+      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
+      zero = scalar_t((z_tmp) + 1);
+    } else if (z_mod == 21){
+      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
+      zero = scalar_t((z_tmp) + 1);
+    } else {
+      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+    }
+
+    if (k_mod == 10) {
+      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
+    } else if (k_mod == 21){
+      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6);
+    } else {
+      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7);
+    }
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){
+	res = 0;
+
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+void vecquant4matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant4matmul_cuda", ([&] {
+      VecQuant4MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 8;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+
+
+  int z_w = w / 8;
+  int z_mod = (w % 8) * 4;
+
+  float weight[BLOCKWIDTH];
+
+  for (k = 0; k <  BLOCKWIDTH; ++k){
+	int k_w = (k / 8);
+	int k_bit = (k % 8) * 4;
+
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
+
+    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
+
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){
+	res = 0;
+
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_cuda", ([&] {
+      VecQuant8MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT8 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 4;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+
+  int z_w = w / 4;
+  int z_mod = (w % 4) * 8;
+
+  float weight[BLOCKWIDTH];
+
+  for (k = 0; k <  BLOCKWIDTH; ++k){
+	int k_w = (k / 4);
+	int k_bit = (k % 4) * 8;
+
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+
+    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
+
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){
+	res = 0;
+
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}

+ 655 - 0
kernels/quantization/gptq/autogptq_cuda_kernel_64.cu

@@ -0,0 +1,655 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+// atomicAdd for double-precision floating-point numbers on hardware with
+// compute capability < 6.0 from:
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
+// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
+// __device__ double atomicAdd(
+//     double* address,
+//     double val
+// ) {
+//   unsigned long long int* address_as_ull = (unsigned long long int*)address;
+//   unsigned long long int old = *address_as_ull, assumed;
+//
+//   do {
+//     assumed = old;
+//     old = atomicCAS(
+//       address_as_ull,
+//       assumed,
+//       __double_as_longlong(val + __longlong_as_double(assumed))
+//     );
+//
+//   // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+//   } while (assumed != old);
+//
+//   return __longlong_as_double(old);
+// }
+// #endif
+
+
+#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
+// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
+__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
+        hsum += val;
+        old = reinterpret_cast<size_t>(address) & 2
+                 ? (old & 0xffff) | (hsum << 16)
+                 : (old & 0xffff0000) | hsum;
+        old = atomicCAS(address_as_ui, assumed, old);
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+}
+__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
+    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        __half_raw hsum;
+        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+        half tmpres = __hadd(hsum, val);
+        hsum = __half_raw(tmpres);
+        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+        old = atomicCAS(address_as_ui, assumed, old);
+    } while (assumed != old);
+}
+#endif
+
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+);
+
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+	const  	    int* __restrict__ g_idx,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+	int zero_width
+);
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  	int* __restrict__ zeros,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  	int* __restrict__ zeros,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel_old(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  	int* __restrict__ zeros,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+__global__ void VecQuant2MatMulKernelFaster_old(
+    const  half2* __restrict__ vec,
+    const    int* __restrict__ mat,
+           float* __restrict__ mul,
+    const  float* __restrict__ scales,
+    const    int* __restrict__ zeros,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+__global__ void VecQuant3MatMulKernelFaster_old(
+    const  half2* __restrict__ vec,
+    const    int* __restrict__ mat,
+           float* __restrict__ mul,
+    const  float* __restrict__ scales,
+    const    int* __restrict__ zeros,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+__global__ void VecQuant4MatMulKernelFaster_old(
+    const  half2* __restrict__ vec,
+    const    int* __restrict__ mat,
+           float* __restrict__ mul,
+    const  float* __restrict__ scales,
+    const    int* __restrict__ zeros,
+    int batch,
+    int vec_height, 	
+    int height,
+    int width,
+    int zero_width,
+    int groupsize
+);
+
+
+const int BLOCKWIDTH  = 64;
+const int BLOCKHEIGHT2 =  4;
+const int BLOCKHEIGHT3 =  6;
+const int BLOCKHEIGHT4 =  8;
+const int BLOCKHEIGHT8 =  16;
+
+__device__ inline unsigned int as_unsigned(int i) {
+  return *reinterpret_cast<unsigned int*>(&i);
+}
+
+__device__ inline int as_int(int i) {
+  return *reinterpret_cast<int*>(&i);
+}
+
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant2matmul_cuda", ([&] {
+      VecQuant2MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant2MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const  		int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT2 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 16;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+  
+  int z_w = w / 16; 
+  int z_mod = (w % 16) * 2;
+  
+  float weight[BLOCKWIDTH];
+  
+  for (k = 0; k <  BLOCKWIDTH; ++k){	
+	int k_w = (k / 16); 
+	int k_bit = (k % 16) * 2;
+	
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
+	
+    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
+    
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){	
+	res = 0;
+	
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){	
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant3matmul_cuda", ([&] {
+      VecQuant3MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant3MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT3 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = (h / 3) * 32;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+  
+  int z_w = (w / 32) * 3; 
+  int z_mod = w % 32;
+  int z_bit;
+  unsigned int z_tmp;
+  if (z_mod != 10){
+    if (z_mod != 21){
+      z_bit = z_mod;
+      if (z_bit > 21){
+        z_bit -= 22;
+        z_bit *= 3;
+        z_bit += 2;
+        z_w += 2;
+      } else if (z_bit > 10){
+        z_bit -= 11;
+        z_bit *= 3;
+        z_bit += 1;
+        z_w += 1;
+      } else {
+        z_bit *= 3;
+      }
+    } else {
+      z_w += 1;
+    }
+  }
+  
+  float weight[BLOCKWIDTH];
+  
+  for (k = 0; k <  BLOCKWIDTH; ++k){	
+	int k_w = (k / 32) * 3; 
+	int k_mod = k % 32;
+	int k_bit;
+	  
+	if (k_mod != 10){
+	  if (k_mod != 21){
+        k_bit = k_mod;
+        if (k_bit > 21){
+		  k_bit -= 22;
+		  k_bit *= 3;
+		  k_bit += 2;
+		  k_w += 2;
+        } else if (k_bit > 10){
+		  k_bit -= 11;
+		  k_bit *= 3;
+		  k_bit += 1;
+		  k_w += 1;
+        } else {
+		  k_bit *= 3;
+        }
+	  } else {
+        k_w += 1;
+	  }
+	}
+	
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero;
+    if (z_mod == 10) {
+      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
+      zero = scalar_t((z_tmp) + 1);
+    } else if (z_mod == 21){
+      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
+      zero = scalar_t((z_tmp) + 1);
+    } else {
+      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+    }
+	
+    if (k_mod == 10) {
+      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
+    } else if (k_mod == 21){
+      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6);
+    } else {
+      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7);
+    }
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){	
+	res = 0;
+	
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){	
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+void vecquant4matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant4matmul_cuda", ([&] {
+      VecQuant4MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant4MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT4 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 8;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+  
+
+  int z_w = w / 8; 
+  int z_mod = (w % 8) * 4;
+  
+  float weight[BLOCKWIDTH];
+  
+  for (k = 0; k <  BLOCKWIDTH; ++k){	
+	int k_w = (k / 8); 
+	int k_bit = (k % 8) * 4;
+	
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
+	
+    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
+    
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){	
+	res = 0;
+	
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){	
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  int batch = vec.size(0);
+  int vec_height = vec.size(1);
+  int height = mat.size(0);
+  int width = mat.size(1);
+  int zero_width = zeros.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  AT_DISPATCH_FLOATING_TYPES(
+    vec.type(), "vecquant8matmul_cuda", ([&] {
+      VecQuant8MatMulKernel<<<blocks, threads>>>(
+        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        batch, vec_height, height, width, zero_width
+      );
+    })
+  );
+}
+
+template <typename scalar_t>
+__global__ void VecQuant8MatMulKernel(
+    const  scalar_t* __restrict__ vec,
+    const       int* __restrict__ mat,
+           scalar_t* __restrict__ mul,
+    const  scalar_t* __restrict__ scales,
+    const       int* __restrict__ zeros,
+    const   	int* __restrict__ g_idx,
+    int batch,
+    int vec_height,
+    int height,
+    int width,
+	int zero_width
+) {
+  int h = BLOCKHEIGHT8 * blockIdx.x;
+  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  
+  __shared__ scalar_t blockvec[BLOCKWIDTH];
+  int i = width * h + w;
+  int g_h = h * 4;
+  int k;
+  unsigned int g;
+  scalar_t w_tmp;
+  
+  int z_w = w / 4; 
+  int z_mod = (w % 4) * 8;
+  
+  float weight[BLOCKWIDTH];
+  
+  for (k = 0; k <  BLOCKWIDTH; ++k){	
+	int k_w = (k / 4); 
+	int k_bit = (k % 4) * 8;
+	
+    g = as_int(g_idx[g_h + k]);
+    scalar_t scale = scales[g * width + w];
+    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+	
+    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
+    
+	weight[k] = scale * (w_tmp - zero);
+  }
+
+  scalar_t res;
+  for (int b = 0; b < batch; ++b){	
+	res = 0;
+	
+    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
+    __syncthreads();
+	for (k = 0; k <  BLOCKWIDTH; ++k){	
+	  res += weight[k] * blockvec[k];
+    }
+    atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
+  }
+}

+ 123 - 0
kernels/quantization/gptq/matrix_view.cuh

@@ -146,6 +146,129 @@ public:
     __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
 };
 
+class MatrixView_q2_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        return (data[row * width / 16 + column / 16] >> shift) & 0x03;
+    }
+
+    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        uint32_t d = data[row * width / 16 + column / 16] >> shift;
+        items[0] = d & 0x03;
+        items[1] = (d >> 2) & 0x03;
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x0f) * 2;
+        uint32_t d = data[row * width / 16 + column / 16] >> shift;
+        items[0] = d & 0x03;
+        items[1] = (d >> 2) & 0x03;
+        items[2] = (d >> 4) & 0x03;
+        items[3] = (d >> 6) & 0x03;
+    }
+};
+
+class MatrixView_q3_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int z_w = column * 3 / 32;
+        int z_mod =  column & 0x1f;
+
+        if (z_mod == 10) {
+            return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
+        } else if (z_mod == 21) {
+            return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
+        } else if (z_mod < 10) {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
+        } else if (z_mod < 21) {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3  - 32)) & 0x07;
+        } else {
+            return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3  - 64)) & 0x07;
+        }
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x1f);
+        uint32_t d;
+        if (shift <= 4) {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
+        } else if (shift == 8) {
+            d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
+        } else if (shift <= 16) {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
+        } else if (shift == 20) {
+            d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
+        } else {
+            d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
+        }
+        items[0] = d & 0x07;
+        items[1] = (d >> 3) & 0x07;
+        items[2] = (d >> 6) & 0x07;
+        items[3] = (d >> 9) & 0x07;
+    }
+};
+
+class MatrixView_q8_row
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (column & 0x03) * 8;
+        return (data[row * width / 4 + column / 4] >> shift) & 0xff;
+    }
+
+    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+    {
+        int shift = (column & 0x03) * 8;
+        uint32_t d = data[row * width / 4 + column / 4] >> shift;
+        items[0] = d & 0xff;
+        items[1] = (d >> 8) & 0xff;
+    }
+
+    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+    {
+        int shift = (column & 0x03) * 2;
+        uint32_t d = data[row * width / 4 + column / 4] >> shift;
+        items[0] = d & 0xff;
+        items[1] = (d >> 8) & 0xff;
+        items[2] = (d >> 16) & 0xff;
+        items[3] = (d >> 24) & 0xff;
+    }
+};
+
 }  // namespace gptq
 }  // namespace aphrodite
 #endif
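For reference, the new 2-bit row view indexes a flat array of 32-bit words with sixteen values per word, least-significant bits first. A NumPy sketch of the same addressing (illustrative only, not part of the kernels):

    import numpy as np

    def q2_row_item(data: np.ndarray, width: int, row: int, column: int) -> int:
        """NumPy equivalent of MatrixView_q2_row::item() for a flat uint32 array."""
        shift = (column & 0x0F) * 2
        word = int(data[(row * width) // 16 + column // 16])
        return (word >> shift) & 0x03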

File diff suppressed because it is too large
+ 911 - 98
kernels/quantization/gptq/q_gemm.cu


+ 87 - 0
kernels/quantization/gptq/qdq_2.cuh

@@ -0,0 +1,87 @@
+/*
+Copied from https://github.com/turboderp/exllamav2
+*/
+
+#ifndef _qdq_2_cuh
+#define _qdq_2_cuh
+
+#include "qdq_util.cuh"
+
+namespace aphrodite {
+namespace gptq {
+
+// Permutation:
+//
+// ffddbb99 77553311  eeccaa88 66442200
+
+__forceinline__ __device__ void shuffle_2bit_16
+(
+    uint32_t* q,
+    int stride
+)
+{
+    uint32_t qa = q[0];
+    uint32_t qb = 0;
+
+    #pragma unroll
+    for (int i = 0; i < 8; i++)
+    {
+        uint32_t qa0 = qa & 0x03;
+        uint32_t qa1 = (qa & 0x0c) >> 2;
+        qa >>= 4;
+        qb |= (qa1 << (i * 2 + 16));
+        qb |= (qa0 << (i * 2));
+    }
+    q[0] = qb;
+}
+
+__forceinline__ __device__ void dequant_2bit_16
+(
+    const uint32_t q_0,
+    half2 (&dq)[8],
+    int stride,
+    const uint32_t zero
+)
+{
+    const uint32_t c0 = 0x64006400;
+    const half y4_  = __float2half_rn(1.0f /  4.0f);
+    const half y16_ = __float2half_rn(1.0f / 16.0f);
+    const half y64_ = __float2half_rn(1.0f / 64.0f);
+    const half2 y4  = __halves2half2(y4_,  y4_);
+    const half2 y16 = __halves2half2(y16_, y16_);
+    const half2 y64 = __halves2half2(y64_, y64_);
+
+    const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
+    const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
+    const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
+    const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
+    const half2 z1 = __half2half2(z1_.as_half);
+    const half2 z4 = __half2half2(z4_);
+    const half2 z16 = __half2half2(z16_);
+    const half2 z64 = __half2half2(z64_);
+
+    uint32_t qa = q_0;
+    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
+    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
+    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
+    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
+    qa >>= 8;
+    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8])      + 1024
+    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
+    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
+    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
+
+    dq[0] = __hadd2(q0.as_half2, z1);
+    dq[1] = __hfma2(q1.as_half2, y4,  z4);
+    dq[2] = __hfma2(q2.as_half2, y16, z16);
+    dq[3] = __hfma2(q3.as_half2, y64, z64);
+    dq[4] = __hadd2(q4.as_half2, z1);
+    dq[5] = __hfma2(q5.as_half2, y4,  z4);
+    dq[6] = __hfma2(q6.as_half2, y16, z16);
+    dq[7] = __hfma2(q7.as_half2, y64, z64);
+}
+
+}  // namespace gptq
+}  // namespace aphrodite
+
+#endif

+ 141 - 0
kernels/quantization/gptq/qdq_3.cuh

@@ -0,0 +1,141 @@
+#ifndef _qdq_3_cuh
+#define _qdq_3_cuh
+
+#include "qdq_util.cuh"
+
+namespace aphrodite {
+namespace gptq {
+// Permutation:
+//
+// v9997775 55333111  u8886664 44222000  (u, v lsb)
+// vjjjhhhf ffdddbbb  uiiiggge eecccaaa
+// vtttrrrp ppnnnlll  usssqqqo oommmkkk
+
+__forceinline__ __device__ void shuffle_3bit_32
+(
+    uint32_t* q,
+    int stride
+)
+{
+    uint32_t qa = q[0 * stride];
+    uint32_t qb = q[1 * stride];
+    uint32_t qc = q[2 * stride];
+
+    // qa: aa999888 77766655  54443332 22111000
+    // qb: lkkkjjji iihhhggg  fffeeedd dcccbbba
+    // qc: vvvuuutt tsssrrrq  qqpppooo nnnmmmll
+
+    uint32_t qd = qc >> 26;
+    qc <<= 4;
+    qc |= qb >> 28;
+    qb <<= 2;
+    qb |= qa >> 30;
+
+    // qa: ..999888 77766655  54443332 22111000
+    // qb: ..jjjiii hhhgggff  feeedddc ccbbbaaa
+    // qc: ..tttsss rrrqqqpp  pooonnnm mmlllkkk
+    // qd:                               vvvuuu
+
+    uint32_t za = 0;
+    uint32_t zb = 0;
+    uint32_t zc = 0;
+
+    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
+    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
+    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
+
+    // za:  9997775 55333111   8886664 44222000
+    // zb:  jjjhhhf ffdddbbb   iiiggge eecccaaa
+    // zc:  tttrrrp ppnnnlll   sssqqqo oommmkkk
+    // qd:                               vvvuuu
+
+    za |= ((qd & 0x01) >> 0) << 15;
+    zb |= ((qd & 0x02) >> 1) << 15;
+    zc |= ((qd & 0x04) >> 2) << 15;
+    za |= ((qd & 0x08) >> 3) << 31;
+    zb |= ((qd & 0x10) >> 4) << 31;
+    zc |= ((qd & 0x20) >> 5) << 31;
+
+    // za: v9997775 55333111  u8886664 44222000  (u, v lsb)
+    // zb: vjjjhhhf ffdddbbb  uiiiggge eecccaaa
+    // zc: vtttrrrp ppnnnlll  usssqqqo oommmkkk
+
+    q[0 * stride] = za;
+    q[1 * stride] = zb;
+    q[2 * stride] = zc;
+}
+
+__forceinline__ __device__ void dequant_3bit_32
+(
+    const uint32_t q_0,
+    const uint32_t q_1,
+    const uint32_t q_2,
+    half2 (&dq)[16],
+    int stride,
+    const uint32_t zero
+)
+{
+    const uint32_t c0 = 0x64006400;
+    const half y8_  = __float2half_rn(1.0f /  8.0f);
+    const half y64_ = __float2half_rn(1.0f / 64.0f);
+    const half2 y8  = __halves2half2(y8_,  y8_);
+    const half2 y64 = __halves2half2(y64_, y64_);
+    const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
+    const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
+    const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
+    const half2 z1  = __halves2half2(z1_.as_half,  z1_.as_half);
+    const half2 z8  = __halves2half2(z8_,  z8_);
+    const half2 z64 = __halves2half2(z64_, z64_);
+
+    uint32_t qa = q_0;
+    uint32_t qb = q_1;
+    uint32_t qc = q_2;
+
+    half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024
+    half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024
+    qa >>= 6;
+    half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024
+    half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024
+    half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
+    qa >>= 9;
+    qa &= 0x00010001;
+    half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024
+    half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024
+    qb >>= 6;
+    half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024
+    half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024
+    half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
+    qb >>= 8;
+    qb &= 0x00020002;
+    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024
+    half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024
+    qc >>= 6;
+    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024
+    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024
+    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
+    qc >>= 7;
+    qc &= 0x00040004;
+    half2_uint32 q15((qa | qb | qc) | c0);
+
+    dq[ 0] = __hadd2( q0.as_half2, z1);
+    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);
+    dq[ 2] = __hadd2( q2.as_half2, z1);
+    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);
+    dq[ 4] = __hfma2( q4.as_half2, y64, z64);
+    dq[ 5] = __hadd2( q5.as_half2, z1);
+    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);
+    dq[ 7] = __hadd2( q7.as_half2, z1);
+    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);
+    dq[ 9] = __hfma2( q9.as_half2, y64, z64);
+    dq[10] = __hadd2(q10.as_half2, z1);
+    dq[11] = __hfma2(q11.as_half2, y8,  z8);
+    dq[12] = __hadd2(q12.as_half2, z1);
+    dq[13] = __hfma2(q13.as_half2, y8,  z8);
+    dq[14] = __hfma2(q14.as_half2, y64, z64);
+    dq[15] = __hadd2(q15.as_half2, z1);
+}
+
+}  // namespace gptq
+}  // namespace aphrodite
+
+#endif
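
Note on the 3-bit path: 32 weights occupy exactly three uint32 words (96 bits). Before the shuffle they form one contiguous little-endian stream of 3-bit fields, as the qa/qb/qc comments show; the shuffle regroups them so dequant_3bit_32 can pull ten weights out of each word with cheap masks and recover the two leftover values (u, v) from the top bits. A host-side C++ reference for the pre-shuffle layout, assuming value j occupies bits [3j, 3j+3) of the stream (ref_unpack_3bit_32 is an illustrative name, not part of the patch):

#include <cstdint>

// Unpack the pre-shuffle layout: value j sits at bits [3*j, 3*j + 3)
// of the 96-bit stream formed by q[0], q[1], q[2] (little-endian).
static void ref_unpack_3bit_32(const uint32_t q[3], uint32_t out[32])
{
    uint64_t lo = (uint64_t)q[0] | ((uint64_t)q[1] << 32);  // stream bits 0..63
    uint64_t hi = (uint64_t)q[2];                           // stream bits 64..95
    for (int j = 0; j < 32; j++)
    {
        int bit = j * 3;
        uint64_t v;
        if (bit + 3 <= 64)
            v = (lo >> bit) & 0x7;
        else if (bit >= 64)
            v = (hi >> (bit - 64)) & 0x7;
        else                                  // only j == 21 straddles the word boundary
            v = ((lo >> bit) | (hi << (64 - bit))) & 0x7;
        out[j] = (uint32_t)v;                 // dequantized weight is (out[j] - zero) * scale
    }
}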

+ 6 - 94
kernels/quantization/gptq/qdq_4.cuh

@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8
 (
     const uint32_t q_0,
     half2 (&dq)[4],
-    int stride
+    int stride,
+    const uint32_t zero
 )
 {
     const uint32_t c0 = 0x64006400;
     const half y16_ = __float2half_rn(1.0f / 16.0f);
     const half2 y16 = __halves2half2(y16_, y16_);
-    const half z1_  = __float2half_rn(-1024.0f         - 8.0f);
-    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
-    const half2 z1  = __halves2half2(z1_,  z1_);
-    const half2 z16 = __halves2half2(z16_, z16_);
+    const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
+    const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
+    const half2 z1 = __half2half2(z1_.as_half);
+    const half2 z16 = __half2half2(z16_);

     uint32_t qa = q_0;
     half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
 }  // namespace gptq
 }  // namespace aphrodite

-#else
-
-namespace aphrodite {
-namespace gptq {
-__forceinline__ __device__ void shuffle_4bit_8
-(
-    uint32_t* q,
-    int stride
-)
-{
-}
-
-__forceinline__ __device__ void dequant_4bit_8
-(
-    const uint32_t q_0,
-    half2 (&dq)[4],
-    int stride
-)
-{
-    half dqh[8];
-    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
-
-    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
-}
-
-__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
-(
-    const uint32_t zero,
-    const half scale,
-    half2 (&z1)[2],
-    half2 (&y1)[2]
-)
-{
-    half z = __int2half_rn(-((int)zero));
-    z = __hmul(z, scale);
-    z1[0] = __half2half2(z);
-    y1[0] = __half2half2(scale);
-}
-
-__forceinline__ __device__ void dequant_4bit_8_prep_zero
-(
-    const uint32_t zero,
-    half2(&z1)[2],
-    half2(&y1)[2]
-)
-{
-    half z = __int2half_rn(-((int)zero));
-    z1[0] = __half2half2(z);
-}
-
-__forceinline__ __device__ void dequant_4bit_8_gptq
-(
-    const uint32_t q_0,
-    half2 (&dq)[4],
-    half2 (&z1)[2],
-    half2 (&y1)[2],
-    int stride,
-    bool scaled
-)
-{
-    half2 dqh2[8];
-
-    uint32_t qa = q_0;
-    for (int i = 0; i < 4; i++)
-    {
-        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
-        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
-        dqh2[i] = __halves2half2(d0, d1);
-    }
-
-    if (scaled)
-    {
-        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
-        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
-        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
-        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
-    }
-    else
-    {
-        dq[0] = __hadd2(dqh2[0], z1[0]);
-        dq[1] = __hadd2(dqh2[1], z1[0]);
-        dq[2] = __hadd2(dqh2[2], z1[0]);
-        dq[3] = __hadd2(dqh2[3], z1[0]);
-    }
-}
-
-}  // namespace gptq
-}  // namespace aphrodite
-
 #endif
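
Note on the 4-bit hunk above: the old code hard-wired a symmetric zero point of 8 (the -1024.0f - 8.0f constants); the new code takes the zero point at run time and builds half(-1024 - zero) without float math by OR-ing the 4-bit zero into the mantissa of -1024.0h (bit pattern 0xe400). At magnitude 1024 the half ulp is exactly 1, so the OR adds zero to the magnitude; the 2-bit and 3-bit kernels use the same construction. A small host-side C++ check of that identity (half_bits_to_float is an illustrative helper, not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE-754 binary16 bit pattern from the normal range
// (sufficient for these constants; no NaN/Inf/subnormal handling).
static float half_bits_to_float(uint16_t h)
{
    int sign = (h >> 15) & 1;
    int exp  = (h >> 10) & 0x1f;
    int man  = h & 0x3ff;
    float v = std::ldexp(1.0f + (float)man / 1024.0f, exp - 15);
    return sign ? -v : v;
}

int main()
{
    for (uint32_t zero = 0; zero < 16; zero++)
    {
        // Same construction as `const half_uint16 z1_(0xe400 | zero)` in the kernels.
        float z1 = half_bits_to_float((uint16_t)(0xe400 | zero));
        assert(z1 == -(1024.0f + (float)zero));
    }
    return 0;
}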

+ 40 - 0
kernels/quantization/gptq/qdq_8.cuh

@@ -0,0 +1,40 @@
+/*
+Copied from https://github.com/turboderp/exllamav2
+*/
+
+#ifndef _qdq_8_cuh
+#define _qdq_8_cuh
+
+#include "qdq_util.cuh"
+
+namespace aphrodite {
+namespace gptq {
+
+__forceinline__ __device__ void shuffle_8bit_4
+(
+    uint32_t* q,
+    int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_8bit_8
+(
+    const uint32_t q_0,
+    const uint32_t q_1,
+    half2 (&dq)[4],
+    int stride,
+    const uint32_t zero
+)
+{
+    half dqh[8];
+    for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
+    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
+
+    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+}  // namespace gptq
+}  // namespace aphrodite
+
+#endif
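
Note on the 8-bit path: no shuffle is needed (shuffle_8bit_4 is a no-op) and no mantissa tricks are used; dequant_8bit_8 extracts eight bytes from two words with the exb/dq_ns helpers from qdq_util.cuh, so the net effect is each weight dequantized to (q - zero) before the group scale is applied elsewhere. A host-side C++ equivalent, for reference (ref_dequant_8bit_8 is an illustrative name, not part of the patch):

#include <cstdint>

// Reference for dequant_8bit_8: eight byte-sized weights from two words,
// each dequantized to (q - zero); the per-group scale is applied elsewhere.
static void ref_dequant_8bit_8(uint32_t q_0, uint32_t q_1, uint32_t zero, float out[8])
{
    for (int i = 0; i < 4; i++)
        out[i]     = (float)((q_0 >> (i * 8)) & 0xff) - (float)zero;
    for (int i = 0; i < 4; i++)
        out[i + 4] = (float)((q_1 >> (i * 8)) & 0xff) - (float)zero;
}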

+ 0 - 1
setup.py

@@ -237,7 +237,6 @@ aphrodite_extension = CUDAExtension(
 )
 ext_modules.append(aphrodite_extension)

-
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)


Some files were not shown because too many files changed in this diff