
feat: re-write GPTQ and refactor exllama kernels (#152)

* refactor exllama kernels

* modify the linear layers

* awq fixes

* update gptq

* the useless quant

* skip bias init in models

* add acknowledgements and minor fixes

* fix mistral (finally)

* formatting

* things are fine, pylint

* abolish pylint

* pylintgit add .!
AlpinDale 1 year ago
parent
commit
62b2c4119d

+ 1 - 0
aphrodite/endpoints/kobold/protocol.py

@@ -71,6 +71,7 @@ class KAIGenerationInputSchema(BaseModel):
     disable_input_formatting: Optional[bool]
     frmtadsnsp: Optional[bool]
     quiet: Optional[bool]
+    # pylint: disable=unexpected-keyword-arg
     sampler_order: Optional[conlist(int, min_items=6)]
     sampler_seed: Optional[conint(ge=0, le=2**64 - 1)]
     sampler_full_determinism: Optional[bool]

+ 1 - 1
aphrodite/endpoints/openai/api_server.py

@@ -204,7 +204,7 @@ def create_logprobs(token_ids: List[int],
         logprobs.tokens.append(token)
         logprobs.token_logprobs.append(id_logprob[token_id])
         if len(logprobs.text_offset) == 0:
-            logprobs.text_offset.append(initial_text_offset)
+            logprobs.text_offset.append(initial_text_offset) # pylint: disable=unsubscriptable-object
         else:
             logprobs.text_offset.append(logprobs.text_offset[-1] +
                                         last_token_len)

+ 23 - 21
aphrodite/modeling/layers/linear.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -21,11 +21,10 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self,
-                       input_size: int,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype,
-                       parallel_type: str = "none") -> Dict[str, torch.Tensor]:
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         """Create weights for a linear layer."""
         raise NotImplementedError
 
@@ -49,13 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self,
-                       input_size: int,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype,
-                       parallel_type: str = "none") -> Dict[str, torch.Tensor]:
-        weight = Parameter(torch.empty(output_size,
-                                       input_size,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
                                        device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
@@ -108,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.params_dtype)
+            self.input_size, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size,
@@ -174,11 +174,12 @@ class ColumnParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.params_dtype,
-            "column")
+            self.input_size, self.output_size_per_partition, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -488,11 +489,12 @@ class RowParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.params_dtype,
-            "row")
+            self.input_size_per_partition, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
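A minimal sketch of the new contract, assuming only the signatures visible in this diff: create_weights now receives both the per-partition and the full sizes, and its return dict may carry non-tensor entries, which the layers skip when registering parameters. ToyLinearMethod and the concrete sizes below are hypothetical.

from typing import Any, Dict

import torch
from torch.nn import Parameter


class ToyLinearMethod:
    """Hypothetical LinearMethodBase-style class using the new signature."""

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Non-tensor state can now travel in the same dict as the weights.
        return {"weight": weight, "exllama_state": "uninitialized"}


module = torch.nn.Module()
linear_weights = ToyLinearMethod().create_weights(512, 128, 1024, 256,
                                                  torch.float16)
for name, weight in linear_weights.items():
    # Mirrors the layers above: only real tensors become parameters.
    if isinstance(weight, torch.Tensor):
        module.register_parameter(name, weight)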

+ 12 - 13
aphrodite/modeling/layers/quantization/awq.py

@@ -84,17 +84,16 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config
 
-    def create_weights(self,
-                       input_size: int,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype,
-                       parallel_type: str = "none") -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.group_size != 0:
+                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
@@ -102,8 +101,8 @@ class AWQLinearMethod(LinearMethodBase):
 
         qweight = Parameter(
             torch.empty(
-                input_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -118,8 +117,8 @@ class AWQLinearMethod(LinearMethodBase):
             })
         qzeros = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -134,8 +133,8 @@ class AWQLinearMethod(LinearMethodBase):
             })
         scales = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
@@ -152,7 +151,7 @@ class AWQLinearMethod(LinearMethodBase):
         }
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
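As a sanity check on the shapes above, a small hypothetical calculation assuming 4-bit AWQ (pack_factor = 32 // 4 = 8) and group_size = 128; the concrete sizes are made up for illustration.

input_size_per_partition = 4096
output_size_per_partition = 11008
pack_factor, group_size = 8, 128

# These are the alignment checks create_weights enforces.
assert input_size_per_partition % group_size == 0
assert output_size_per_partition % pack_factor == 0

qweight_shape = (input_size_per_partition,
                 output_size_per_partition // pack_factor)   # int32, packed along N
qzeros_shape = (input_size_per_partition // group_size,
                output_size_per_partition // pack_factor)    # one packed zero-point row per group
scales_shape = (input_size_per_partition // group_size,
                output_size_per_partition)                    # fp16 scales per group
print(qweight_shape, qzeros_shape, scales_shape)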

+ 54 - 87
aphrodite/modeling/layers/quantization/gptq.py

@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -8,13 +9,11 @@ from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import (
     QuantizationConfig)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
 
 
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
-    Reference: https://arxiv.org/abs/2306.00978
+    Reference: https://arxiv.org/abs/2210.17323
     """
 
     def __init__(
@@ -71,6 +70,9 @@ class GPTQConfig(QuantizationConfig):
         return []
 
 
+ExlState = Enum("ExlState", ["Unused", "Uninitialized", "Ready"])
+
+
 class GPTQLinearMethod(LinearMethodBase):
     """Linear method for GPTQ.
     Args:
@@ -79,28 +81,42 @@ class GPTQLinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: GPTQConfig):
         self.quant_config = quant_config
-        self.use_exllama = True
 
-    def create_weights(self,
-                       input_size: int,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype,
-                       parallel_type: str = "none") -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.group_size != 0:
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
+        exllama_state = ExlState.Uninitialized
+        scale_and_zero_size = input_size // group_size
+        scale_and_zero_input_dim = None
+        if (input_size != input_size_per_partition
+                and self.quant_config.group_size != -1):
+            # For act-order models, we cannot use Exllama for row parallel layer
+            if self.quant_config.desc_act:
+                exllama_state = ExlState.Unused
+            else:
+                # we need to partition qzeros and scales for exllama kernel
+                scale_and_zero_size = input_size_per_partition // group_size
+                scale_and_zero_input_dim = 0
 
         qweight = Parameter(
             torch.empty(
-                input_size // self.quant_config.pack_factor,
-                output_size,
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -115,37 +131,20 @@ class GPTQLinearMethod(LinearMethodBase):
             })
         g_idx = Parameter(
             torch.tensor(
-                [i // self.quant_config.group_size for i in range(input_size)],
+                [
+                    i // self.quant_config.group_size
+                    for i in range(input_size_per_partition)
+                ],
                 device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
         )
         set_weight_attrs(g_idx, {"input_dim": 0})
-        tp_size = get_tensor_model_parallel_world_size()
-        if parallel_type == "row" and tp_size > 1 and (
-                self.quant_config.desc_act
-                and self.quant_config.group_size != -1):
-            input_size = input_size * tp_size
-            use_exllama = Parameter(torch.tensor(False,
-                                                 dtype=torch.bool,
-                                                 device="cuda"),
-                                    requires_grad=False)
-        else:
-            use_exllama = Parameter(torch.tensor(True,
-                                                 dtype=torch.bool,
-                                                 device="cuda"),
-                                    requires_grad=False)
-        if self.quant_config.desc_act or self.quant_config.group_size == -1:
-            input_dim = None
-        else:
-            input_dim = 0
-        # pylint: disable=line-too-long
-        group_size = self.quant_config.group_size if self.quant_config.group_size != -1 else input_size
         qzeros = Parameter(
             torch.empty(
-                input_size // group_size,
-                output_size // self.quant_config.pack_factor,
+                scale_and_zero_size,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -153,22 +152,22 @@ class GPTQLinearMethod(LinearMethodBase):
         )
         set_weight_attrs(
             qzeros, {
-                "input_dim": input_dim,
+                "input_dim": scale_and_zero_input_dim,
                 "output_dim": 1,
                 "packed_dim": 1,
                 "pack_factor": self.quant_config.pack_factor,
             })
         scales = Parameter(
             torch.empty(
-                input_size // group_size,
-                output_size,
+                scale_and_zero_size,
+                output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,
         )
         set_weight_attrs(scales, {
-            "input_dim": input_dim,
+            "input_dim": scale_and_zero_input_dim,
             "output_dim": 1,
         })
         return {
@@ -176,62 +175,30 @@ class GPTQLinearMethod(LinearMethodBase):
             "g_idx": g_idx,
             "qzeros": qzeros,
             "scales": scales,
-            "use_exllama": use_exllama,
+            "exllama_state": exllama_state,
         }
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
-        height, width = weights["qweight"].shape
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
-        if weights["use_exllama"]:
-            if "q4" not in weights:
-                if not self.quant_config.desc_act:
-                    none_tensor = torch.empty((1, 1), device="meta")
-                    weights["q4"] = quantization_ops.make_q_matrix(
-                        weights["qweight"],
-                        none_tensor,
-                        none_tensor,
-                        weights["qzeros"],
-                        weights["scales"],
-                        none_tensor,
-                    )
-                else:
-                    weights["q_perm"] = torch.empty(
-                        (height * self.quant_config.pack_factor, ),
-                        dtype=torch.short,
-                        device=weights["qweight"].device)
-                    weights["q_invperm"] = torch.empty_like(weights["q_perm"])
-                    weights["q4"] = quantization_ops.make_q_matrix(
-                        weights["qweight"],
-                        weights["q_perm"],
-                        weights["q_invperm"],
-                        weights["qzeros"],
-                        weights["scales"],
-                        weights["g_idx"].cpu(),
-                    )
-            temp_dq = torch.empty(
-                (height * self.quant_config.pack_factor, width),
-                dtype=torch.float16,
-                device=x.device)
-            output = torch.empty((reshaped_x.shape[0], qweight.shape[-1]),
-                                 dtype=torch.float16,
-                                 device=x.device)
-            quantization_ops.gemm_half_q_half(reshaped_x, weights["q4"],
-                                              output, temp_dq, False)
-        else:
-            output = torch.zeros((reshaped_x.shape[0], qweight.shape[-1]),
-                                 dtype=torch.float32,
-                                 device=x.device)
-            quantization_ops.gptq_descact_matmul(reshaped_x.float(),
-                                                 weights["qweight"], output,
-                                                 weights["scales"].float(),
-                                                 weights["qzeros"],
-                                                 weights["g_idx"])
-            output = output.half()
+        # exllama needs to shuffle the weight after it's loaded
+        # here we do the shuffle on the first forward pass
+        if weights["exllama_state"] == ExlState.Uninitialized:
+            if self.quant_config.desc_act:
+                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
+                    torch.int)
+            else:
+                weights["g_idx"] = torch.empty((1, 1), device="meta")
+            weights["exllama_state"] = ExlState.Ready
+            quantization_ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
+        output = quantization_ops.gptq_gemm(
+            reshaped_x, weights["qweight"], weights["qzeros"],
+            weights["scales"], weights["g_idx"],
+            weights["exllama_state"] == ExlState.Ready)
         if bias is not None:
             output = output + bias
         return output.reshape(out_shape)
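A minimal, self-contained sketch of the lazy initialization above: the CUDA op is replaced by a stub so only the ExlState transition is visible; the real path calls quantization_ops.gptq_shuffle once on the first forward pass and quantization_ops.gptq_gemm on every forward.

from enum import Enum

import torch

ExlState = Enum("ExlState", ["Unused", "Uninitialized", "Ready"])


def gptq_shuffle_stub(qweight: torch.Tensor, g_idx: torch.Tensor) -> None:
    # Stand-in for quantization_ops.gptq_shuffle (needs the CUDA extension).
    pass


def apply_weights_sketch(weights: dict, x: torch.Tensor,
                         desc_act: bool = True) -> torch.Tensor:
    if weights["exllama_state"] == ExlState.Uninitialized:
        if desc_act:
            # act-order models: sort g_idx into the exllama permutation
            weights["g_idx"] = torch.argsort(weights["g_idx"]).to(torch.int)
        else:
            weights["g_idx"] = torch.empty((1, 1), device="meta")
        weights["exllama_state"] = ExlState.Ready
        gptq_shuffle_stub(weights["qweight"], weights["g_idx"])
    # Here the real code calls quantization_ops.gptq_gemm(..., use_exllama=True).
    return x


weights = {
    "qweight": torch.zeros(4, 8, dtype=torch.int32),
    "g_idx": torch.tensor([0, 0, 1, 1], dtype=torch.int32),
    "exllama_state": ExlState.Uninitialized,
}
apply_weights_sketch(weights, torch.zeros(1, 4))
assert weights["exllama_state"] == ExlState.Ready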

+ 8 - 9
aphrodite/modeling/layers/quantization/squeezellm.py

@@ -67,20 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config
 
-    def create_weights(self,
-                       input_size: int,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype,
-                       parallel_type: str = "none") -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.pack_factor != 0:
+                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
         qweight = Parameter(
             torch.empty(
-                input_size // self.quant_config.pack_factor,
-                output_size,
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -95,7 +94,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             })
         lookup_table = Parameter(
             torch.empty(
-                output_size,
+                output_size_per_partition,
                 self.quant_config.weight_bits**2,
                 device="cuda",
                 dtype=params_dtype,
@@ -111,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         }
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]

+ 4 - 2
aphrodite/modeling/models/gpt_j.py

@@ -286,14 +286,16 @@ class GPTJForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                if name not in params_dict:
+                # skip loading extra bias for GPTQ models
+                if name.endswith("bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name not in params_dict:
+                # skip loading extra bias for GPTQ models
+                if name.endswith("bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

+ 4 - 2
aphrodite/modeling/models/llama.py

@@ -333,14 +333,16 @@ class LlamaForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                if name not in params_dict:
+                # skip loading extra bias for GPTQ models
+                if name.endswith("bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name not in params_dict:
+                # skip loading extra bias for GPTQ models
+                if name.endswith("bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

+ 8 - 13
aphrodite/modeling/models/mistral.py

@@ -21,11 +21,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only Mistral model compatible with HuggingFace weights.
-
-The input of the model is flattened to a 1D tensor of tokens. The model uses
-InputMetadata to extract the original 2D shape of the input.
-"""
+"""Inference-only Mistral model compatible with HuggingFace weights."""
 from typing import List, Optional, Tuple
 
 import torch
@@ -131,6 +127,7 @@ class MistralAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
+
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -199,7 +196,7 @@ class MistralDecoderLayer(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
             residual = hidden_states
@@ -214,7 +211,6 @@ class MistralDecoderLayer(nn.Module):
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -256,10 +252,7 @@ class MistralModel(nn.Module):
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            if cache_events is None:
-                cache_event = None
-            else:
-                cache_event = cache_events[i]
+            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -330,14 +323,16 @@ class MistralForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                if name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

+ 2 - 2
aphrodite/modeling/models/phi1_5.py

@@ -271,8 +271,8 @@ class PhiForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            # pylint: disable=E1136
-            if name not in params_dict:
+            # skip loading extra bias for GPTQ models
+            if name.endswith("bias") and name not in params_dict:
                 continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",

+ 4 - 2
aphrodite/modeling/models/yi.py

@@ -326,14 +326,16 @@ class YiForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                if name not in params_dict:
+                # skip loading extra bias for GPTQ models
+                if name.endswith("bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name not in params_dict:
+                # skip loading extra bias for GPTQ models
+                if name.endswith("bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

+ 14 - 21
kernels/ops.h

@@ -78,26 +78,19 @@ void squeezellm_gemm(
   torch::Tensor mul,
   torch::Tensor lookup_table);
 
-uintptr_t make_q_matrix(
-    torch::Tensor q_weight,
-    torch::Tensor q_perm,
-    torch::Tensor q_invperm,
-    torch::Tensor gptq_qzeros,
-    torch::Tensor gptq_scales,
-    torch::Tensor gptq_g_idx);
+torch::Tensor gptq_gemm
+(
+  torch::Tensor a,
+  torch::Tensor b_q_weight,
+  torch::Tensor b_gptq_qzeros,
+  torch::Tensor b_gptq_scales,
+  torch::Tensor b_g_idx,
+  bool use_exllama
+);
 
-void gemm_half_q_half(
-    torch::Tensor a,
-    uintptr_t b,
-    torch::Tensor c,
-    torch::Tensor temp_dq,
-    bool force_cuda);
-
-void gptq_descact_matmul(
-  torch::Tensor vec,
-  torch::Tensor mat,
-  torch::Tensor mul,
-  torch::Tensor scales,
-  torch::Tensor zeros,
-  torch::Tensor g_idx);
+void gptq_shuffle
+(
+  torch::Tensor q_weight,
+  torch::Tensor q_perm
+);
   

+ 2 - 3
kernels/pybind.cpp

@@ -50,12 +50,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
   // Quantization ops
   #ifndef USE_ROCM
+  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
+  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   #endif
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-  ops.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
-  ops.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
-  ops.def("gptq_descact_matmul", &gptq_descact_matmul, "Quantized GEMM for GPTQ for parallelized desc_act layer.");
 
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "Aphrodite Engine cache ops");
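For reference, a hypothetical Python-side helper showing how the two new bindings line up with the declarations in kernels/ops.h; the ops module is passed in as a parameter because actually running it requires the compiled CUDA extension.

import torch


def gptq_matmul(ops, x: torch.Tensor, qweight: torch.Tensor,
                qzeros: torch.Tensor, scales: torch.Tensor,
                g_idx: torch.Tensor, use_exllama: bool) -> torch.Tensor:
    # ops exposes gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
    # b_g_idx, use_exllama) and gptq_shuffle(q_weight, q_perm), as bound above.
    reshaped_x = x.reshape(-1, x.shape[-1])
    out = ops.gptq_gemm(reshaped_x, qweight, qzeros, scales, g_idx, use_exllama)
    return out.reshape(x.shape[:-1] + (qweight.shape[-1], ))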

+ 8 - 0
kernels/quantization/gptq/compat.cuh

@@ -1,6 +1,12 @@
+/*
+Copied from https://github.com/turboderp/exllamav2
+*/
+
 #ifndef _compat_cuh
 #define _compat_cuh
 
+namespace aphrodite {
+namespace gptq {
 // atomicAdd for half types, to support CC < 7.x
 
 __device__ __forceinline__ void atomicAdd_half(half* address, half val)
@@ -53,4 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
 #endif
 #endif
 
+}  // namespace gptq
+}  // namespace aphrodite
 #endif

+ 0 - 100
kernels/quantization/gptq/exllama_ext.cpp

@@ -1,100 +0,0 @@
-#include <torch/extension.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime.h>
-#include <cuda_fp16.h>
-#include <cstdint>
-#include <cstdio>
-
-#include "q_matrix.cuh"
-#include "q_gemm.cuh"
-
-// Some decluttering macros
-
-#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
-#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
-#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
-#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
-
-
-// Quant matrix
-
-uintptr_t make_q_matrix
-(
-    torch::Tensor q_weight,
-    torch::Tensor q_perm,
-    torch::Tensor q_invperm,
-    torch::Tensor gptq_qzeros,
-    torch::Tensor gptq_scales,
-    torch::Tensor gptq_g_idx
-)
-{
-    TORCH_CHECK_DTYPE(q_weight, kInt);
-    TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
-    TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
-    TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
-    TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
-    TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
-    TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
-    TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
-
-    TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
-
-    int device = q_weight.device().index();
-    int width = q_weight.size(1);
-    int groups;
-    int height;
-
-    groups = gptq_qzeros.size(0);
-    height = q_weight.size(0) * 8;
-
-    QMatrix* m = new QMatrix
-    (
-        device,
-        height,
-        width,
-        groups,
-        (uint32_t*) q_weight.data_ptr(),
-        q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
-        q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
-        gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
-        gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
-        gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr()
-    );
-
-    return reinterpret_cast<uintptr_t> (m);
-}
-
-void gemm_half_q_half
-(
-    torch::Tensor a,
-    uintptr_t b,
-    torch::Tensor c,
-    torch::Tensor temp_dq,
-    bool force_cuda
-)
-{
-    QMatrix* qm = reinterpret_cast<QMatrix*> (b);
-
-    TORCH_CHECK_DTYPE(a, kHalf);
-    TORCH_CHECK_DTYPE(c, kHalf);
-    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
-    TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
-    TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
-
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
-
-    gemm_half_q_half_cuda
-    (
-        at::cuda::getCurrentCUDABlasHandle(),
-        (const half*) a.data_ptr(),
-        qm,
-        (half*) c.data_ptr(),
-        c.size(0), // m
-        c.size(1), // n
-        a.size(1), // k
-        true,
-        (half*) temp_dq.data_ptr(),
-        force_cuda
-    );
-}

+ 30 - 0
kernels/quantization/gptq/matrix_view.cuh

@@ -1,3 +1,7 @@
+/*
+Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
+*/
+
 #ifndef _matrix_view_cuh
 #define _matrix_view_cuh
 
@@ -6,6 +10,9 @@
 
 #include "qdq_util.cuh"
 
+namespace aphrodite {
+namespace gptq {
+
 class MatrixView_half
 {
 public:
@@ -118,4 +125,27 @@ public:
     }
 };
 
+class MatrixView_q4_column
+{
+public:
+    const uint32_t* data;
+    const int height;
+    const int width;
+
+    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
+        : data(data), height(height), width(width)
+    { }
+
+    __device__ __forceinline__ int item(int row, int column) const
+    {
+        int shift = (row & 0x07) * 4;
+        return (data[row / 8 * width + column] >> shift) & 0x0f;
+    }
+
+    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
+    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
+};
+
+}  // namespace gptq
+}  // namespace aphrodite
 #endif

+ 0 - 125
kernels/quantization/gptq/old_matmul_kernel.cu

@@ -1,125 +0,0 @@
-#include <torch/all.h>
-#include <torch/python.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <cuda_fp16.h>
-#include "compat.cuh"
-
-const int BLOCKWIDTH  = 256;
-const int BLOCKHEIGHT =  32;
-
-__device__ inline unsigned int as_unsigned(int i) {
-  return *reinterpret_cast<unsigned int*>(&i);
-}
-
-__device__ inline int as_int(int i) {
-  return *reinterpret_cast<int*>(&i);
-}
-
-template <typename scalar_t>
-__global__ void VecQuant4MatMulKernel(
-    const  scalar_t* __restrict__ vec,
-    const       int* __restrict__ mat,
-           scalar_t* __restrict__ mul,
-    const  scalar_t* __restrict__ scales,
-    const       int* __restrict__ zeros,
-    const   	int* __restrict__ g_idx,
-    int batch,
-    int vec_height,
-    int height,
-    int width,
-	int zero_width
-) {
-    int h = BLOCKHEIGHT * blockIdx.x;
-    int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
-    int h_end = min(h + BLOCKHEIGHT, height);
-
-    __shared__ scalar_t blockvec[BLOCKWIDTH];
-    int i = width * h + w;
-    int g_h = h * 8;
-    int h_range = (h_end - h) * 8;
-    int k;
-    unsigned int g;
-    scalar_t w_tmp;
-
-
-    int z_w = w / 8;
-    int z_mod = (w % 8) * 4;
-
-    float weight[BLOCKWIDTH];
-
-    if (w < width) {
-        for (k = 0; k < h_range; ++k) {
-    	      int k_w = (k / 8);
-	          int k_bit = (k % 8) * 4;
-
-            g = as_int(g_idx[g_h + k]);
-            scalar_t scale = scales[g * width + w];
-            scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
-            w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
-            weight[k] = scale * (w_tmp - zero);
-        }
-    }
-
-    scalar_t res;
-    for (int b = 0; b < batch; ++b) {
-	    res = 0;
-
-        if (threadIdx.x < h_range) {
-            blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
-        }
-        __syncthreads();
-        if (w < width) {
-	        for (k = 0; k < h_range; ++k){
-	            res += weight[k] * blockvec[k];
-            }
-            atomicAdd(&mul[b * width + w], res);
-        }
-        __syncthreads();
-    }
-}
-
-void vecquant4matmul_cuda(
-    torch::Tensor vec,
-    torch::Tensor mat,
-    torch::Tensor mul,
-    torch::Tensor scales,
-    torch::Tensor zeros,
-    torch::Tensor g_idx
-) {
-    int batch = vec.size(0);
-    int vec_height = vec.size(1);
-    int height = mat.size(0);
-    int width = mat.size(1);
-    int zero_width = zeros.size(1);
-
-    dim3 blocks(
-        (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
-        (width + BLOCKWIDTH - 1) / BLOCKWIDTH
-    );
-    dim3 threads(BLOCKWIDTH);
-
-    AT_DISPATCH_FLOATING_TYPES(
-        vec.type(), "vecquant4matmul_cuda", ([&] {
-            VecQuant4MatMulKernel<<<blocks, threads>>>(
-                vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-                scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
-                batch, vec_height, height, width, zero_width
-            );
-        })
-    );
-}
-
-void gptq_descact_matmul(
-  torch::Tensor vec,
-  torch::Tensor mat,
-  torch::Tensor mul,
-  torch::Tensor scales,
-  torch::Tensor zeros,
-  torch::Tensor g_idx
-)
-{
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
-  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
-}

+ 741 - 48
kernels/quantization/gptq/q_gemm.cu

@@ -1,17 +1,32 @@
-#include "q_gemm.cuh"
-#include "matrix_view.cuh"
+/*
+Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
+*/
+
+#include <cstdint>
+#include <cstdio>
 
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#include "compat.cuh"
+#include "matrix_view.cuh"
 #include "qdq_4.cuh"
 
+namespace aphrodite {
+namespace gptq {
+
 #define BLOCK_KN_SIZE 128
 #define BLOCK_M_SIZE_MAX 8
 #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
-#define CLEAR_N_SIZE 256
 #define MAX_Q_GEMM_ROWS 50
+#define MAX_ALT_GEMM_ROWS 8
+#define THREADS_X 32
+#define THREADS_Y 32
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
 
-#include "q_gemm_kernel_gptq.cuh"
-
 #if defined(USE_ROCM)
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                                hipblasOperation_t transA,
@@ -41,16 +56,224 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
 #define rocblas_hgemm __compat_hipblasHgemm
 #endif
 
+__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    return __hadd2(result, g_result);
+}
+
+__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
+{
+    half2 result = {};
+    const half2* a2_ptr = (const half2*)a_ptr;
+    #pragma unroll
+    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+    return __half2float(__low2half(result)) + __half2float(__high2half(result));
+}
+
+typedef void (*fp_gemm_half_q_half_gptq_kernel)
+(
+    const half*,
+    const uint32_t*,
+    const uint32_t*,
+    const half*,
+    half*,
+    const int,
+    const int,
+    const int,
+    const int,
+    const int*
+);
+
+template <bool first_block, int m_count>
+__global__ void gemm_half_q_half_gptq_kernel
+(
+    const half* __restrict__ a,
+    const uint32_t* __restrict__ b_q_weight,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    half* __restrict__ c,
+    const int size_m,
+    const int size_n,
+    const int size_k,
+    const int groups,
+    const int* __restrict__ b_q_perm
+)
+{
+    MatrixView_half a_(a, size_m, size_k);
+    MatrixView_half_rw c_(c, size_m, size_n);
+    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int t = threadIdx.x;
+
+    // Block
+    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+    int offset_m = blockIdx.y * m_count;
+    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+
+    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+    int end_m = min(offset_m + m_count, size_m);
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    int n = offset_n + t * 4;
+
+    // Preload block_a
+    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+
+    if (offset_k + t < end_k)
+    {
+        for (int m = 0; m < m_count; ++m)
+        {
+            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
+            half* block_a_ptr = block_a[m];
+
+            half a0;
+            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
+            else a0 = a_ptr[offset_k + t];
+            block_a_ptr[t] = a0;
+        }
+    }
+
+    // Zero output
+    if (n >= size_n) return;
+
+    if (blockIdx.z == 0)
+    {
+        for (int m = 0; m < m_count; m++)
+            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
+    }
+
+    __syncthreads();
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // a, b offset
+    int qk = offset_k / (32 / 4);
+
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+    const half* a_ptr = &block_a[0][0];
+    int a_stride = BLOCK_KN_SIZE;
+
+    // Initial group
+    int zeros[4];
+    float scales[4];
+    half2 z1z16[4][2];
+    half2 y1y16[4][2];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4_f(scales, group, n);
+    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
+    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
+    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
+    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+
+    // Column result
+    float block_c[m_count][4] = {};
+
+    // Dequantize and multiply
+    int k = offset_k;
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4_f(scales, group, n);
+            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
+            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
+            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
+            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+        }
+
+        #pragma unroll
+        for (int j = 0; j < 4; j++)
+        {
+            const int4* b_ptr4 = (int4*) b_ptr;
+            int4 load_int4 = *b_ptr4;
+
+            half2 dq[4][4];
+            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+
+            #pragma unroll
+            for (int m = 0; m < m_count; m++)
+            {
+                block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
+                block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
+                block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
+                block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
+            }
+
+            b_ptr += size_n;
+            a_ptr += 8;
+        }
+
+        k += 32;
+    }
+
+    for (int m = 0; m < m_count; m++)
+    {
+        half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
+        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
+        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
+        atomicAdd(out    , result01);
+        atomicAdd(out + 1, result23);
+    }
+}
+
+
+fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
+{
+    #if BLOCK_M_SIZE_MAX >= 1
+    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 2
+    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 3
+    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 4
+    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 5
+    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 6
+    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 7
+    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 8
+    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
+    #endif
+    return NULL;
+}
+
+
 void gemm_half_q_half_cuda_part
 (
     const half* a,
-    QMatrix* b,
+    const uint32_t* b_q_weight,
+    const uint32_t* b_gptq_qzeros,
+    const half* b_gptq_scales,
+    const int* b_q_perm,
     half* c,
     int size_m,
     int size_n,
     int size_k,
     int m_count,
-    bool clear
+    int groups
 )
 {
     dim3 blockDim, gridDim;
@@ -66,44 +289,391 @@ void gemm_half_q_half_cuda_part
     kernel<<<gridDim, blockDim>>>
     (
         a,
-        b->cuda_q_weight,
-        b->cuda_gptq_qzeros,
-        b->cuda_gptq_scales,
+        b_q_weight,
+        b_gptq_qzeros,
+        b_gptq_scales,
         c,
         size_m,
         size_n,
         size_k,
-        b->groups,
-        b->groupsize,
-        b->cuda_q_perm,
-        clear
+        groups,
+        b_q_perm
+    );
+}
+
+
+__global__ void reconstruct_exllama_kernel
+(
+    const uint32_t* __restrict__ b_q_weight,
+    const int* __restrict__ b_q_perm,
+    const uint32_t* __restrict__ b_gptq_qzeros,
+    const half* __restrict__ b_gptq_scales,
+    const int size_k,
+    const int size_n,
+    const int groups,
+    half* __restrict__ b
+)
+{
+    MatrixView_half_rw b_(b, size_k, size_n);
+    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+
+    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+    // Preload remapping table
+    __shared__ int perm[BLOCK_KN_SIZE];
+    int t = threadIdx.x;
+
+    if (b_q_perm)
+    {
+        if (offset_k + t < size_k)
+            perm[t] = b_q_perm[offset_k + t];
+    }
+
+    // Column
+    int n = offset_n + t * 4;
+    if (n >= size_n) return;
+
+    // Find initial group
+    int groupsize = size_k / groups;
+    int group = offset_k / groupsize;
+    int nextgroup = offset_k + groupsize;
+
+    // b offset
+    int qk = offset_k / (32 / 4);
+
+    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+    // Initial zeros/scale
+    int zeros[4];
+    half2 scales[4];
+    half2 z1z16[4][2];
+    half2 y1y16[4][2];
+    b_gptq_qzeros_.item4(zeros, group, n);
+    b_gptq_scales_.item4_h2(scales, group, n);
+    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
+    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
+    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
+    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+
+    __syncthreads();
+
+    int k = offset_k;
+    int lk = 0;
+
+    while (k < end_k)
+    {
+        if (k == nextgroup)
+        {
+            group++;
+            nextgroup += groupsize;
+            b_gptq_qzeros_.item4(zeros, group, n);
+            b_gptq_scales_.item4_h2(scales, group, n);
+            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
+            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
+            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
+            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+        }
+
+        for (int p = 0; p < 4; p++)
+        {
+            half2 dq[4][4];
+            const int4* b_ptr4 = (int4*) b_ptr;
+            int4 load_int4 = *b_ptr4;
+
+            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+
+            b_ptr += size_n;
+            //half* dqh = (half*)dq;
+            if (b_q_perm)
+            {
+                for (int j = 0; j < 4; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+            else
+            {
+                for (int j = 0; j < 4; j++)
+                {
+                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+                }
+            }
+        }
+        k += 32;
+    }
+}
+
+
+void reconstruct_exllama
+(
+    const uint32_t* b_q_weight,
+    const uint32_t* b_gptq_qzeros,
+    const half* b_gptq_scales,
+    const int* b_q_perm,
+    half* out,
+    int height,
+    int width,
+    int groups
+)
+{
+    dim3 blockDim, gridDim;
+    blockDim.x = BLOCK_KN_SIZE;
+    blockDim.y = 1;
+    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+
+    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    (
+        b_q_weight,
+        b_q_perm,
+        b_gptq_qzeros,
+        b_gptq_scales,
+        height,
+        width,
+        groups,
+        out
+    );
+}
+
+
+__global__ void gemm_half_q_half_alt_kernel(
+    const half2* __restrict__ vec,
+    const uint32_t* __restrict__ mat,
+    half* __restrict__ mul,
+    const half* __restrict__ scales,
+    const uint32_t* __restrict__ zeros,
+    const int* __restrict__ g_idx,
+    int batch,
+    int height,
+    int width
+)
+{
+    int zero_width = width / 8;
+    int vec_height = height * 4;
+    const int blockwidth2 = BLOCK_KN_SIZE / 2;
+    int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+    int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
+    int h = BLOCK_KN_SIZE * blockIdx.z / 8;
+    int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
+    int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+
+    __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
+    if (threadIdx.x < h_end) {
+        for (int m = 0; m < b_end; ++m) {
+          blockvec[m][threadIdx.x] =
+              vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
+                  threadIdx.x];
+        }
+    }
+
+    __shared__ half2 deq2[256][8];
+    int val = threadIdx.x / 8;
+    int off = threadIdx.x % 8;
+    for (; val < 256; val += BLOCK_KN_SIZE / 8) {
+        deq2[val][off] = __halves2half2(
+            __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
+        );
+    }
+
+    if (blockIdx.z == 0)
+    {
+        for (int m = 0; m < b_end; m++)
+            mul[(b + m) * width + w] = __int2half_rn(0);
+    }
+    __syncthreads();
+
+    int i = width * h + w;
+    int g_h = h * 8;
+    int k = 0;
+    int z_w = w / 8;
+    int z_mod = (w % 8) * 4;
+    half2 res2;
+    half res[BLOCK_M_SIZE_MAX] = {};
+
+    unsigned int tmp;
+    while (k < h_end) {
+        tmp = mat[i];
+        half2 scales_tmp[4];
+        half2 zeros_tmp[4];
+        for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
+            int g = g_idx[g_h + (k + tmp_k) * 2];
+            int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
+            half scale_f = scales[g * width + w];
+            half scale_f2 = scales[g2 * width + w];
+            half2 scale = __halves2half2(scale_f, scale_f2);
+            half2 zero = __halves2half2(
+                __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
+                __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
+            );
+            scales_tmp[tmp_k] = scale;
+            zeros_tmp[tmp_k] = zero;
+        }
+        for (int m = 0; m < b_end; m++) {
+            res2 = {};
+            res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
+            res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
+            res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
+            res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+        }
+        i += width;
+        k += 4;
+    }
+    for (int m = 0; m < b_end; m++) {
+        atomicAdd(&mul[(b + m) * width + w], res[m]);
+    }
+}
+
+
+void gemm_half_q_half_alt
+(
+    const half* a,
+    const uint32_t* b_q_weight,
+    const uint32_t* b_gptq_qzeros,
+    const half* b_gptq_scales,
+    const int* b_g_idx,
+    half* c,
+    int size_m,
+    int size_n,
+    int size_k
+)
+{
+    dim3 blockDim, gridDim;
+    blockDim.x = BLOCK_KN_SIZE;
+    blockDim.y = 1;
+    blockDim.z = 1;
+    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
+    gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
+    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    (
+        (const half2*) a,
+        b_q_weight,
+        c,
+        b_gptq_scales,
+        b_gptq_qzeros,
+        b_g_idx,
+        size_m,
+        size_k / 8,
+        size_n
+    );
+}
+
+
+__global__ void reconstruct_gptq_kernel
+(
+    const uint32_t* __restrict__ w,
+    const half* __restrict__ w_scales,
+    const uint32_t* __restrict__ w_zeros,
+    const int* __restrict__ g_idx,
+    const int height,
+    const int width,
+    const int group,
+    half* __restrict__ out
+)
+{
+    // Start of block
+
+    int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+    int row = blockIdx.y * 8;
+    if (column >= width) return;
+
+    // Views
+
+    MatrixView_q4_column w_(w, height, width);
+    MatrixView_half_rw out_(out, height, width);
+    MatrixView_half w_scales_(w_scales, group, width);
+    MatrixView_q4_row w_zeros_(w_zeros, group, width);
+
+    uint32_t w_read = w_.item_uint32_t(row, column);
+    half* out_ptr = out_.item_ptr(row, column);
+
+    #pragma unroll
+    for (int s = 0; s < 32; s += 4)
+    {
+        int group = g_idx[row + s / 4];
+        half w_scale = w_scales_.item(group, column);
+        uint32_t w_zero = w_zeros_.item(group, column) + 1;
+        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
+        *out_ptr = w_item; out_ptr += out_.width;
+    }
+}
+
+
+void reconstruct_gptq
+(
+    const uint32_t* b_q_weight,
+    const uint32_t* b_gptq_qzeros,
+    const half* b_gptq_scales,
+    const int* b_g_idx,
+    half* out,
+    int height,
+    int width,
+    int groups
+)
+{
+    dim3 blockDim, gridDim;
+    blockDim.x = BLOCK_KN_SIZE;
+    blockDim.y = 1;
+    gridDim.y = DIVIDE(height, 8);
+    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    (
+        b_q_weight,
+        b_gptq_scales,
+        b_gptq_qzeros,
+        b_g_idx,
+        height,
+        width,
+        groups,
+        out
     );
 }
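
In the reconstruct path each thread handles one column of one packed uint32, i.e. eight consecutive K rows, and every unpacked row looks up its own group through g_idx before applying scale * (q - (zero + 1)). A CPU sketch of that unpacking (floats stand in for halves, and the 4-bit zero points are assumed pre-unpacked to one integer per group/column for brevity):

    #include <cstdint>

    // Illustrative CPU analogue of reconstruct_gptq_kernel for one (row, column):
    // a single uint32 packs eight 4-bit weights along K; each one picks its group
    // via g_idx and is dequantized as scale * (q - (zero + 1)).
    void reconstruct_column(uint32_t w_read, int row, int column, int width,
                            const int* g_idx,    // [height]
                            const float* scales, // [groups * width], row-major
                            const int* zeros,    // [groups * width], pre-unpacked
                            float* out)          // [height * width], row-major
    {
        for (int s = 0; s < 32; s += 4) {
            int g = g_idx[row + s / 4];
            float scale = scales[g * width + column];
            int zero = zeros[g * width + column] + 1;
            int q = (int)((w_read >> s) & 0x0f);
            out[(row + s / 4) * width + column] = scale * (float)(q - zero);
        }
    }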
 
+
 void gemm_half_q_half_cuda
 (
     cublasHandle_t cublas_handle,
     const half* a,
-    QMatrix* b,
+    const uint32_t* b_q_weight,
+    const uint32_t* b_gptq_qzeros,
+    const half* b_gptq_scales,
+    const int* b_g_idx,
     half* c,
+    half* temp_dq,
     int size_m,
     int size_n,
     int size_k,
-    bool clear,
-    half* temp_dq,
-    bool force_cuda
+    int groups,
+    bool use_exllama
 )
 {
-    if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
-    {
-
+    if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
         // Reconstruct FP16 matrix, then cuBLAS
-        b->reconstruct(temp_dq);
-
-        //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
+        if (use_exllama) {
+            reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
+                                size_k, size_n, groups);
+        }
+        else
+        {
+            reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
+                             temp_dq, size_k, size_n, groups);
+        }
 
         const half alpha = __float2half(1.0f);
-        const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
+        const half beta = __float2half(0.0f);
         cublasHgemm(cublas_handle,
                     CUBLAS_OP_N,
                     CUBLAS_OP_N,
@@ -111,56 +681,179 @@ void gemm_half_q_half_cuda
                     &alpha, temp_dq, size_n,
                             a,       size_k,
                     &beta,  c,       size_n);
-
     }
-    else
+    else if (use_exllama)
     {
         // Quantized matmul
-
-        //if (clear) clear_tensor_cuda(c, size_m, size_n);
-
         int max_chunks = size_m / BLOCK_M_SIZE_MAX;
         int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
         int last_chunk_size = size_m - last_chunk;
 
         if (max_chunks)
         {
-            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
+            gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
+                                        c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
+                                        groups);
         }
 
         if (last_chunk_size)
         {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
+            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
+                                        b_gptq_scales, b_g_idx, c + last_chunk * size_n,
+                                        last_chunk_size, size_n, size_k, last_chunk_size,
+                                        groups);
         }
     }
+    else
+    {
+        gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
+                             c, size_m, size_n, size_k);
+    }
 }
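
To summarize the dispatch above: tall inputs dequantize the whole weight into temp_dq and fall back to cublasHgemm, short inputs use the fused kernels, and use_exllama picks between the exllama path (chunked by BLOCK_M_SIZE_MAX) and the original GPTQ alt kernel. A sketch of the decision, with the row thresholds left as parameters since MAX_Q_GEMM_ROWS and MAX_ALT_GEMM_ROWS are defined earlier in this file:

    // Path selection mirroring gemm_half_q_half_cuda (sketch only).
    enum class GemmPath { ReconstructThenCublas, ExllamaKernel, AltKernel };

    GemmPath pick_gemm_path(bool use_exllama, int size_m,
                            int max_q_gemm_rows, int max_alt_gemm_rows) {
        // Tall inputs: reconstruct FP16 weights into temp_dq, then cublasHgemm.
        if ((use_exllama && size_m > max_q_gemm_rows) ||
            (!use_exllama && size_m > max_alt_gemm_rows))
            return GemmPath::ReconstructThenCublas;
        // Otherwise run the fused dequant + matmul kernels directly.
        return use_exllama ? GemmPath::ExllamaKernel : GemmPath::AltKernel;
    }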
 
-__global__ void clear_kernel
+
+__global__ void shuffle_kernel
 (
-    half* __restrict__ c,
-    const int size_m,
+    uint32_t* __restrict__ b_q_weight,
+    const int size_k,
     const int size_n
 )
 {
-    int m = blockIdx.y;
-    int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
+    int n = blockIdx.x * THREADS_X + threadIdx.x;
     if (n >= size_n) return;
-    int4* c_ptr = (int4*)(c + m * size_n + n);
-    *c_ptr = {};
+    int k = 0;
+    uint32_t* b_ptr = b_q_weight + n;
+    while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  8; }
 }
 
-void clear_tensor_cuda
+
+__global__ void make_sequential_kernel
 (
-    half* c,
-    int size_m,
-    int size_n
+    const uint32_t* __restrict__ w,
+    uint32_t* __restrict__ w_new,
+    const int* __restrict__ q_perm,
+    const int w_height,
+    const int w_width
 )
 {
-    return;
+    const uint64_t* w2 = (uint64_t*) w;
+    uint64_t* w_new2 = (uint64_t*) w_new;
+    int w2_stride = w_width >> 1;
+    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+    if (w2_column >= w2_stride) return;
+    int w_new2_row = blockIdx.y;
+    int q_perm_idx = w_new2_row << 3;
+    uint64_t dst = 0;
+
+    #pragma unroll
+    for (int i = 0; i < 8; i++)
+    {
+        int source_row = q_perm[q_perm_idx++];
+
+        int w2_row = source_row >> 3;
+        int w2_subrow = source_row & 0x07;
+        int w2_row_shift = w2_subrow << 2;
+        int wnew2_row_shift = i << 2;
+
+        uint64_t src = w2[w2_row * w2_stride + w2_column];
+        src >>= w2_row_shift;
+        src &= 0x0000000f0000000f;
+        src <<= wnew2_row_shift;
+        dst |= src;
+    }
+    w_new2[w_new2_row * w2_stride + w2_column] = dst;
+}
+
+
+void shuffle_exllama_weight
+(
+    uint32_t* q_weight,
+    int* q_perm,
+    int height,
+    int width
+)
+{
+    if (q_perm)
+    {
+        uint32_t* new_qweight = NULL;
+        cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));
+
+        dim3 blockDim, gridDim;
+        blockDim.x = THREADS_X;
+        blockDim.y = 1;
+        gridDim.x = DIVIDE(width, THREADS_X);
+        gridDim.y = height / 8;
+
+        make_sequential_kernel<<<gridDim, blockDim>>>
+        (
+            q_weight,
+            new_qweight,
+            q_perm,
+            height / 8,
+            width
+        );
+        // Replace qweights
+        cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
+        // Cleanup
+        cudaDeviceSynchronize();
+        cudaFree(new_qweight);
+    }
     dim3 blockDim, gridDim;
-    blockDim.x = CLEAR_N_SIZE;
+    blockDim.x = THREADS_X;
     blockDim.y = 1;
-    gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
-    gridDim.y = size_m;
-    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
+    gridDim.x = DIVIDE(width, THREADS_X);
+    gridDim.y = 1;
+    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+}
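
shuffle_exllama_weight first re-orders the packed rows by q_perm so that rows belonging to the same group become contiguous (the job the removed QMatrix::make_sequential used to do), then applies the exllama nibble shuffle per column. The re-ordering works on 64-bit words, each covering two adjacent columns and eight packed rows; a CPU sketch of the per-word gather done by make_sequential_kernel:

    #include <cstdint>

    // CPU analogue of make_sequential_kernel's inner loop: build one output
    // uint64 (two adjacent columns, eight 4-bit rows each) by gathering the
    // eight source rows listed in q_perm for this destination row.
    uint64_t gather_packed_rows(const uint64_t* w2, const int* q_perm,
                                int w2_stride, int w2_column, int w_new2_row) {
        uint64_t dst = 0;
        for (int i = 0; i < 8; i++) {
            int source_row = q_perm[w_new2_row * 8 + i];
            int w2_row = source_row >> 3;                 // packed word holding the source row
            int w2_row_shift = (source_row & 0x07) << 2;  // nibble lane inside that word
            uint64_t src = w2[w2_row * w2_stride + w2_column];
            src >>= w2_row_shift;
            src &= 0x0000000f0000000full;                 // one nibble from each 32-bit half
            dst |= src << (i << 2);                       // place it in destination lane i
        }
        return dst;
    }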
+
+}  // namespace gptq
+}  // namespace aphrodite
+
+torch::Tensor gptq_gemm
+(
+    torch::Tensor a,
+    torch::Tensor b_q_weight,
+    torch::Tensor b_gptq_qzeros,
+    torch::Tensor b_gptq_scales,
+    torch::Tensor b_g_idx,
+    bool use_exllama
+)
+{
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
+    at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
+    at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);
+
+    aphrodite::gptq::gemm_half_q_half_cuda
+    (
+        at::cuda::getCurrentCUDABlasHandle(),
+        (const half*) a.data_ptr(),
+        (const uint32_t*) b_q_weight.data_ptr(),
+        (const uint32_t*) b_gptq_qzeros.data_ptr(),
+        (const half*) b_gptq_scales.data_ptr(),
+        b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
+        (half*) c.data_ptr(),
+        (half*) temp_dq.data_ptr(),
+        c.size(0),  // m
+        c.size(1),  // n
+        a.size(1),  // k
+        b_gptq_qzeros.size(0),  // group number
+        use_exllama
+    );
+    return c;
+}
+
+void gptq_shuffle
+(
+    torch::Tensor q_weight,
+    torch::Tensor q_perm
+)
+{
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
+    aphrodite::gptq::shuffle_exllama_weight(
+        (uint32_t*) q_weight.data_ptr(),
+        q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
+        q_weight.size(0) * 8,
+        q_weight.size(1)
+    );
 }
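
Shape bookkeeping behind the gptq_gemm binding: with 4-bit packing there are eight weights per uint32 along K, so k = a.size(1) = b_q_weight.size(0) * 8, n = b_q_weight.size(1), and temp_dq is the [k, n] scratch (in the activation dtype) used when the reconstruct-plus-cuBLAS path is taken. A small sketch of that arithmetic, with illustrative names only:

    #include <cassert>
    #include <cstdint>

    // Illustrative shape derivation for the gptq_gemm binding (4-bit GPTQ).
    struct GptqGemmShapes {
        int64_t m, n, k;        // output is [m, n], activations are [m, k]
        int64_t temp_dq_elems;  // dequantized weight scratch, [k, n]
    };

    GptqGemmShapes derive_shapes(int64_t a_rows, int64_t a_cols,
                                 int64_t qweight_rows, int64_t qweight_cols) {
        GptqGemmShapes s;
        s.m = a_rows;
        s.n = qweight_cols;
        s.k = qweight_rows * 8;  // eight 4-bit weights per packed uint32 row
        assert(a_cols == s.k);   // K as seen by the activations must match
        s.temp_dq_elems = s.k * s.n;
        return s;
    }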

+ 0 - 33
kernels/quantization/gptq/q_gemm.cuh

@@ -1,33 +0,0 @@
-#ifndef _q_gemm_cuh
-#define _q_gemm_cuh
-
-#include <cuda_runtime.h>
-#include <cuda_fp16.h>
-#include <cstdint>
-#include <cstdio>
-#include <ATen/cuda/CUDAContext.h>
-
-#include "q_matrix.cuh"
-
-void gemm_half_q_half_cuda
-(
-    cublasHandle_t cublas_handle,
-    const half* a,
-    QMatrix* b,
-    half* c,
-    int size_m,
-    int size_n,
-    int size_k,
-    bool clear = false,
-    half* reconstruct = NULL,
-    bool force_cuda = false
-);
-
-void clear_tensor_cuda
-(
-    half* c,
-    int size_m,
-    int size_n
-);
-
-#endif

+ 0 - 217
kernels/quantization/gptq/q_gemm_kernel_gptq.cuh

@@ -1,217 +0,0 @@
-#include "compat.cuh"
-
-__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
-{
-    half2 result = {};
-    const half2* a2_ptr = (const half2*)a_ptr;
-    #pragma unroll
-    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
-    return __hadd2(result, g_result);
-}
-
-__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
-{
-    half2 result = {};
-    const half2* a2_ptr = (const half2*)a_ptr;
-    #pragma unroll
-    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
-    return __half2float(__low2half(result)) + __half2float(__high2half(result));
-}
-
-typedef void (*fp_gemm_half_q_half_gptq_kernel)
-(
-    const half*,
-    const uint32_t*,
-    const uint32_t*,
-    const half*,
-    half*,
-    const int,
-    const int,
-    const int,
-    const int,
-    const int,
-    const uint16_t*,
-    const bool
-);
-
-template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_gptq_kernel
-(
-    const half* __restrict__ a,
-    const uint32_t* __restrict__ b_q_weight,
-    const uint32_t* __restrict__ b_gptq_qzeros,
-    const half* __restrict__ b_gptq_scales,
-    half* __restrict__ c,
-    const int size_m,
-    const int size_n,
-    const int size_k,
-    const int groups,
-    const int groupsize,
-    const uint16_t* __restrict__ b_q_perm,
-    const bool clear
-)
-{
-    MatrixView_half a_(a, size_m, size_k);
-    MatrixView_half_rw c_(c, size_m, size_n);
-    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
-    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
-
-    int t = threadIdx.x;
-
-    // Block
-
-    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-    int offset_m = blockIdx.y * m_count;
-    int offset_k = blockIdx.z * BLOCK_KN_SIZE;
-
-    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
-    int end_m = min(offset_m + m_count, size_m);
-    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
-
-    int n = offset_n + t * 4;
-
-    // Preload block_a
-
-    __shared__ half block_a[m_count][BLOCK_KN_SIZE];
-
-    if (offset_k + t < end_k)
-    {
-        for (int m = 0; m < m_count; ++m)
-        {
-            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
-            half* block_a_ptr = block_a[m];
-
-            half a0;
-            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
-            else a0 = a_ptr[offset_k + t];
-            block_a_ptr[t] = a0;
-        }
-    }
-
-    // Zero output
-
-    if (n >= size_n) return;
-
-    if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
-    {
-        for (int m = 0; m < m_count; m++)
-            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
-    }
-
-    __syncthreads();
-
-    // Find initial group
-
-    int group = offset_k / groupsize;
-    int nextgroup = offset_k + groupsize;
-
-    // a, b offset
-
-    int qk = offset_k / (32 / 4);
-
-    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
-    const half* a_ptr = &block_a[0][0];
-    int a_stride = BLOCK_KN_SIZE;
-
-    // Initial group
-
-    int zeros[4];
-    float scales[4];
-    half2 z1z16[4][2];
-    half2 y1y16[4][2];
-    b_gptq_qzeros_.item4(zeros, group, n);
-    b_gptq_scales_.item4_f(scales, group, n);
-    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
-
-//    __syncthreads();
-
-    // Column result
-
-    float block_c[m_count][4] = {};
-
-    // Dequantize and multiply
-
-    int k = offset_k;
-    while (k < end_k)
-    {
-        if (k == nextgroup)
-        {
-            group++;
-            nextgroup += groupsize;
-            b_gptq_qzeros_.item4(zeros, group, n);
-            b_gptq_scales_.item4_f(scales, group, n);
-            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
-        }
-
-        #pragma unroll
-        for (int j = 0; j < 4; j++)
-        {
-            const int4* b_ptr4 = (int4*) b_ptr;
-            int4 load_int4 = *b_ptr4;
-
-            half2 dq[4][4];
-            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
-
-            #pragma unroll
-            for (int m = 0; m < m_count; m++)
-            {
-                block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
-                block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
-                block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
-                block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
-            }
-
-            b_ptr += size_n;
-            a_ptr += 8;
-        }
-
-        k += 32;
-    }
-
-    for (int m = 0; m < m_count; m++)
-    {
-        half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
-        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
-        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
-        atomicAdd(out    , result01);
-        atomicAdd(out + 1, result23);
-    }
-}
-
-fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
-{
-    #if BLOCK_M_SIZE_MAX >= 1
-    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 2
-    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 3
-    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 4
-    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 5
-    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 6
-    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 7
-    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
-    #endif
-    #if BLOCK_M_SIZE_MAX >= 8
-    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
-    #endif
-    return NULL;
-}

+ 0 - 338
kernels/quantization/gptq/q_matrix.cu

@@ -1,338 +0,0 @@
-#include "q_matrix.cuh"
-#include "matrix_view.cuh"
-
-#include "qdq_4.cuh"
-
-#define BLOCK_KN_SIZE 128
-
-#define THREADS_X 32
-#define THREADS_Y 32
-#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
-
-// Shuffle quantized data on load
-
-__global__ void shuffle_kernel
-(
-    uint32_t* __restrict__ b_q_weight,
-    const int size_k,
-    const int size_n
-)
-{
-    int n = blockIdx.x * THREADS_X + threadIdx.x;
-    if (n >= size_n) return;
-    int k = 0;
-    uint32_t* b_ptr = b_q_weight + n;
-    while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  8; }
-}
-
-
-// QMatrix constructor
-
-QMatrix::QMatrix
-(
-    const int _device,
-    const int _height,
-    const int _width,
-    const int _groups,
-
-    uint32_t* _q_weight,
-    uint16_t* _q_perm,
-    uint16_t* _q_invperm,
-
-    uint32_t* _gptq_qzeros,
-    half* _gptq_scales,
-    uint32_t* _gptq_g_idx
-) :
-    device(_device),
-    height(_height),
-    width(_width),
-    groups(_groups)
-{
-    cudaSetDevice(device);
-
-    cuda_q_weight = _q_weight;
-    cuda_q_perm = _q_perm;
-    cuda_q_invperm = _q_invperm;
-    cuda_gptq_qzeros = _gptq_qzeros;
-    cuda_gptq_scales = _gptq_scales;
-
-    is_gptq = true;
-
-    groupsize = 1;
-    while (groupsize * groups < height) groupsize *= 2;
-
-    if (_gptq_g_idx) make_sequential(_gptq_g_idx);
-
-    // Shuffle quantized data
-
-    dim3 blockDim, gridDim;
-    blockDim.x = THREADS_X;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(width, THREADS_X);
-    gridDim.y = 1;
-
-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width);
-}
-
-
-// Reconstruct b[k,n] (GPTQ)
-
-__global__ void reconstruct_gptq_kernel
-(
-    const uint32_t* __restrict__ b_q_weight,
-    const uint16_t* __restrict__ b_q_perm,
-    const uint32_t* __restrict__ b_gptq_qzeros,
-    const half* __restrict__ b_gptq_scales,
-    const int size_k,
-    const int size_n,
-    const int groupsize,
-    const int groups,
-    half* __restrict__ b
-)
-{
-    MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
-    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
-
-    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
-
-    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
-
-    // Preload remapping table
-
-    __shared__ uint16_t perm[BLOCK_KN_SIZE];
-    int t = threadIdx.x;
-
-    if (b_q_perm)
-    {
-        if (offset_k + t < size_k)
-            perm[t] = b_q_perm[offset_k + t];
-    }
-
-    // Column
-
-    int n = offset_n + t * 4;
-    if (n >= size_n) return;
-
-    // Find initial group
-
-    int group = offset_k / groupsize;
-    int nextgroup = offset_k + groupsize;
-
-    // b offset
-
-    int qk = offset_k / (32 / 4);
-
-    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
-
-    // Initial zeros/scale
-
-    int zeros[4];
-    half2 scales[4];
-    half2 z1z16[4][2];
-    half2 y1y16[4][2];
-    b_gptq_qzeros_.item4(zeros, group, n);
-    b_gptq_scales_.item4_h2(scales, group, n);
-    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
-
-    __syncthreads();
-
-    int k = offset_k;
-    int lk = 0;
-
-    while (k < end_k)
-    {
-        if (k == nextgroup)
-        {
-            group++;
-            nextgroup += groupsize;
-            b_gptq_qzeros_.item4(zeros, group, n);
-            b_gptq_scales_.item4_h2(scales, group, n);
-            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
-        }
-
-        for (int p = 0; p < 4; p++)
-        {
-            half2 dq[4][4];
-            const int4* b_ptr4 = (int4*) b_ptr;
-            int4 load_int4 = *b_ptr4;
-
-            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
-
-            b_ptr += size_n;
-            //half* dqh = (half*)dq;
-            if (b_q_perm)
-            {
-                for (int j = 0; j < 4; j++)
-                {
-                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
-                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
-                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
-                }
-            }
-            else
-            {
-                for (int j = 0; j < 4; j++)
-                {
-                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
-                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
-                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
-                }
-            }
-        }
-        k += 32;
-    }
-}
-
-void QMatrix::reconstruct(half* out)
-{
-    dim3 blockDim, gridDim;
-    blockDim.x = BLOCK_KN_SIZE;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
-
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
-    (
-        cuda_q_weight,
-        cuda_q_perm,
-        cuda_gptq_qzeros,
-        cuda_gptq_scales,
-        height,
-        width,
-        groupsize,
-        groups,
-        out
-    );
-}
-
-__global__ void make_sequential_kernel
-(
-    const uint32_t* __restrict__ w,
-    uint32_t* __restrict__ w_new,
-    const uint16_t* __restrict__ q_perm,
-    const int w_height,
-    const int w_width
-)
-{
-    const uint64_t* w2 = (uint64_t*) w;
-    uint64_t* w_new2 = (uint64_t*) w_new;
-    int w2_stride = w_width >> 1;
-
-    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
-    if (w2_column >= w2_stride) return;
-
-    int w_new2_row = blockIdx.y;
-
-    int q_perm_idx = w_new2_row << 3;
-
-    uint64_t dst = 0;
-
-    #pragma unroll
-    for (int i = 0; i < 8; i++)
-    {
-        int source_row = q_perm[q_perm_idx++];
-
-        int w2_row = source_row >> 3;
-        int w2_subrow = source_row & 0x07;
-        int w2_row_shift = w2_subrow << 2;
-        int wnew2_row_shift = i << 2;
-
-        uint64_t src = w2[w2_row * w2_stride + w2_column];
-        src >>= w2_row_shift;
-        src &= 0x0000000f0000000f;
-        src <<= wnew2_row_shift;
-        dst |= src;
-    }
-
-    w_new2[w_new2_row * w2_stride + w2_column] = dst;
-}
-
-void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
-{
-    uint32_t* cuda_new_qweight = NULL;
-    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
-
-    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
-    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
-    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
-
-    // Group histogram
-
-    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
-
-    // Group map
-
-    for (int i = 0, acc = 0; i < groups; i++)
-    {
-        short tmp = cpu_g_idx_map[i];
-        cpu_g_idx_map[i] = acc;
-        acc += tmp;
-    }
-
-    // X map (inverse)
-
-    for (int row = 0; row < height; row++)
-    {
-        uint32_t target_group = cpu_g_idx[row];
-        uint32_t target_row = cpu_g_idx_map[target_group];
-        cpu_g_idx_map[target_group]++;
-        cpu_x_map_inv[row] = target_row;
-    }
-
-    // X map
-
-    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
-
-    // Reduce to uint16_t
-
-    uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
-    uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
-    for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
-    for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
-
-    // Move to CUDA
-
-    cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
-    cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
-
-    // Rearrange rows in w
-
-    dim3 blockDim, gridDim;
-    blockDim.x = THREADS_X;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(width, THREADS_X);
-    gridDim.y = height / 8;
-
-    make_sequential_kernel<<<gridDim, blockDim>>>
-    (
-        cuda_q_weight,
-        cuda_new_qweight,
-        cuda_q_perm,
-        height / 8,
-        width
-    );
-
-    // Replace qweights
-
-    cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
-
-    // Cleanup
-
-    cudaDeviceSynchronize();
-
-    cudaFree(cuda_new_qweight);
-    free(cpu_g_idx_map);
-    free(cpu_x_map);
-    free(cpu_x_map_inv);
-}

+ 0 - 54
kernels/quantization/gptq/q_matrix.cuh

@@ -1,54 +0,0 @@
-#ifndef _q_matrix_cuh
-#define _q_matrix_cuh
-
-#include <cuda_runtime.h>
-#include <cuda_fp16.h>
-#include <cstdint>
-#include <cstdio>
-
-#define MAX_SUPERGROUPS 16
-
-class QMatrix
-{
-public:
-
-    int device;
-    bool is_gptq;
-
-    int height;
-    int width;
-    int groups;
-    int groupsize;
-
-    uint32_t* cuda_q_weight = NULL;
-    uint16_t* cuda_q_perm = NULL;
-    uint16_t* cuda_q_invperm = NULL;
-    uint32_t* cuda_gptq_qzeros = NULL;
-    half* cuda_gptq_scales = NULL;
-
-    QMatrix
-    (
-        const int _device,
-        const int _height,
-        const int _width,
-        const int _groups,
-
-        uint32_t* _q_weight,
-        uint16_t* _q_perm,
-        uint16_t* _q_invperm,
-
-        uint32_t* _gptq_qzeros,
-        half* _gptq_scales,
-        uint32_t* _gptq_g_idx
-    );
-
-    ~QMatrix();
-
-    void reconstruct(half* out);
-    void make_sequential(const uint32_t* cpu_g_idx);
-
-private:
-
-};
-
-#endif

+ 13 - 0
kernels/quantization/gptq/qdq_4.cuh

@@ -1,8 +1,14 @@
+/*
+Copied from https://github.com/turboderp/exllamav2
+*/
+
 #ifndef _qdq_4_cuh
 #define _qdq_4_cuh
 
 #include "qdq_util.cuh"
 
+namespace aphrodite {
+namespace gptq {
 // Permutation:
 //
 // 77775555 33331111  66664444 22220000
@@ -134,9 +140,13 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
         dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z )
     }
 }
+}  // namespace gptq
+}  // namespace aphrodite
 
 #else
 
+namespace aphrodite {
+namespace gptq {
 __forceinline__ __device__ void shuffle_4bit_8
 (
     uint32_t* q,
@@ -219,4 +229,7 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
     }
 }
 
+}  // namespace gptq
+}  // namespace aphrodite
+
 #endif

+ 9 - 0
kernels/quantization/gptq/qdq_util.cuh

@@ -1,6 +1,13 @@
+/*
+Copied from https://github.com/turboderp/exllamav2
+*/
+
 #ifndef _qdq_util_cuh
 #define _qdq_util_cuh
 
+namespace aphrodite {
+namespace gptq {
+
 union half2_uint32
 {
     uint32_t as_uint32;
@@ -48,4 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
     return (int)(__funnelshift_rc(q0, q1, shift) & mask);
 }
 
+}  // namespace gptq
+}  // namespace aphrodite
 #endif

+ 0 - 3
setup.py

@@ -219,10 +219,7 @@ aphrodite_extension_sources = [
     "kernels/activation_kernels.cu",
     "kernels/layernorm_kernels.cu",
     "kernels/quantization/squeezellm/quant_cuda_kernel.cu",
-    "kernels/quantization/gptq/exllama_ext.cpp",
-    "kernels/quantization/gptq/q_matrix.cu",
     "kernels/quantization/gptq/q_gemm.cu",
-    "kernels/quantization/gptq/old_matmul_kernel.cu",
     "kernels/cuda_utils_kernels.cu",
     "kernels/pybind.cpp",
 ]