Bläddra i källkod

feat: GGUF, QuIP#, and Marlin support (#228)

* add gguf/marlin/quip kernels and modeling code

* add support for all models

* add conversion script

* fix quant init

* fix pybind and ops, add to benchmark script

* formatting

* add fast-hadamard-transform as a dep
AlpinDale 1 år sedan
förälder
incheckning
c3a221eb02
35 ändrade filer med 7049 tillägg och 151 borttagningar
  1. 0 1
      .gitignore
  2. 1 0
      .pylintrc
  3. 2 2
      aphrodite/common/config.py
  4. 4 3
      aphrodite/endpoints/llm.py
  5. 12 11
      aphrodite/engine/args_tools.py
  6. 5 0
      aphrodite/modeling/hf_downloader.py
  7. 25 5
      aphrodite/modeling/layers/linear.py
  8. 4 0
      aphrodite/modeling/layers/quantization/__init__.py
  9. 6 0
      aphrodite/modeling/layers/quantization/awq.py
  10. 13 1
      aphrodite/modeling/layers/quantization/base_config.py
  11. 150 0
      aphrodite/modeling/layers/quantization/gguf.py
  12. 118 17
      aphrodite/modeling/layers/quantization/gptq.py
  13. BIN
      aphrodite/modeling/layers/quantization/hadamard.safetensors
  14. 199 0
      aphrodite/modeling/layers/quantization/quip.py
  15. 126 0
      aphrodite/modeling/layers/quantization/quip_utils.py
  16. 6 0
      aphrodite/modeling/layers/quantization/squeezellm.py
  17. 6 7
      aphrodite/modeling/layers/sampler.py
  18. 35 20
      aphrodite/modeling/layers/vocab_parallel_embedding.py
  19. 37 10
      aphrodite/modeling/models/gpt_j.py
  20. 6 1
      aphrodite/modeling/models/gpt_neox.py
  21. 67 17
      aphrodite/modeling/models/llama.py
  22. 67 17
      aphrodite/modeling/models/mistral.py
  23. 46 15
      aphrodite/modeling/models/mixtral.py
  24. 336 0
      aphrodite/modeling/models/phi.py
  25. 71 22
      aphrodite/modeling/models/yi.py
  26. 2 0
      aphrodite/task_handler/model_runner.py
  27. 151 0
      examples/gguf_to_torch.py
  28. 46 0
      kernels/ops.h
  29. 7 0
      kernels/pybind.cpp
  30. 3925 0
      kernels/quantization/gguf/gguf_kernel.cu
  31. 843 0
      kernels/quantization/marlin/marlin_cuda_kernel.cu
  32. 722 0
      kernels/quantization/quip/origin_order.cu
  33. 3 0
      requirements.txt
  34. 4 0
      setup.py
  35. 4 2
      tests/benchmarks/throughput.py

+ 0 - 1
.gitignore

@@ -12,7 +12,6 @@ dist*
 conda/
 umamba.exe
 bin/
-models/
 *.whl
 # Byte-compiled / optimized / DLL files
 __pycache__/

+ 1 - 0
.pylintrc

@@ -92,6 +92,7 @@ disable=abstract-method,
         logging-not-lazy,
         long-builtin,
         long-suffix,
+        line-too-long,
         map-builtin-not-iterating,
         misplaced-comparison-constant,
         missing-class-docstring,

+ 2 - 2
aphrodite/common/config.py

@@ -148,8 +148,8 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm"]
-        rocm_not_supported_quantization = ["awq"]
+        supported_quantization = ["awq", "gguf", "gptq", "quip", "squeezellm"]
+        rocm_not_supported_quantization = ["awq", "quip"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 

+ 4 - 3
aphrodite/endpoints/llm.py

@@ -39,9 +39,10 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq", "squeezellm", and "gptq". If None, we assume the
-            model weights are not quantized and use `dtype` to determine the
-            data type of the weights.
+            we support "awq", "gptq", "quip" and "squeezellm". If None,
+            we first check the `quantization_config` attribute in the model
+            config file. If that is None, we assume the model weights are not
+            quantized and use `dtype` to determine the data type of the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
         seed: The seed to initialize the random number generator for sampling.

+ 12 - 11
aphrodite/engine/args_tools.py

@@ -197,17 +197,18 @@ class EngineArgs:
                             action='store_true',
                             help='disable logging statistics')
         # Quantization settings.
-        parser.add_argument('--quantization',
-                            '-q',
-                            type=str,
-                            choices=['awq', 'gptq', 'squeezellm', None],
-                            default=None,
-                            help='Method used to quantize the weights. If '
-                            'None, we first check the `quantization_config` '
-                            'attribute in the model config file. If that is '
-                            'None, we assume the model weights are not '
-                            'quantized and use `dtype` to determine the data '
-                            'type of the weights.')
+        parser.add_argument(
+            '--quantization',
+            '-q',
+            type=str,
+            choices=['awq', 'gguf', 'gptq', 'quip', 'squeezellm', None],
+            default=None,
+            help='Method used to quantize the weights. If '
+            'None, we first check the `quantization_config` '
+            'attribute in the model config file. If that is '
+            'None, we assume the model weights are not '
+            'quantized and use `dtype` to determine the data '
+            'type of the weights.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '

+ 5 - 0
aphrodite/modeling/hf_downloader.py

@@ -90,6 +90,9 @@ def get_quant_config(
     cache_dir: Optional[str] = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(quantization)
+    # No need for extra config
+    if quantization == "gguf":
+        return quant_cls()
     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(hf_config, "quantization_config", None)
     if hf_quant_config is not None:
@@ -281,6 +284,8 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
 def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
+    if isinstance(param, torch.nn.parameter.UninitializedParameter):
+        param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
     assert param.size() == loaded_weight.size()
     param.data.copy_(loaded_weight)
 

+ 25 - 5
aphrodite/modeling/layers/linear.py

@@ -71,6 +71,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
             return F.linear(x, weight)
         return F.linear(x, weight, bias)
 
+    def apply_embedding(self, weights: Dict[str, torch.Tensor],
+                        x: torch.Tensor) -> torch.Tensor:
+        weight = weights["weight"]
+        return F.embedding(x, weight)
+
 
 class ReplicatedLinear(torch.nn.Module):
     """Replicated linear layer.
@@ -109,7 +114,7 @@ class ReplicatedLinear(torch.nn.Module):
             self.input_size, self.output_size, self.input_size,
             self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
+            if isinstance(weight, torch.nn.parameter.Parameter):
                 self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
@@ -177,7 +182,7 @@ class ColumnParallelLinear(torch.nn.Module):
             self.input_size, self.output_size_per_partition, self.input_size,
             self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
+            if isinstance(weight, torch.nn.parameter.Parameter):
                 self.register_parameter(name, weight)
                 set_weight_attrs(weight, {"weight_loader": self.weight_loader})
         if bias:
@@ -194,13 +199,20 @@ class ColumnParallelLinear(torch.nn.Module):
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
         if output_dim is not None:
-            shard_size = param_data.shape[output_dim]
+            if loaded_weight.shape[output_dim] % tp_size != 0:
+                raise ValueError("Size is not aligned with the "
+                                 "quantized weight shape")
+            shard_size = loaded_weight.shape[output_dim] // tp_size
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        if isinstance(param, torch.nn.parameter.UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+            param_data = param.data
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -499,7 +511,7 @@ class RowParallelLinear(torch.nn.Module):
             self.input_size_per_partition, self.output_size, self.input_size,
             self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
+            if isinstance(weight, torch.nn.parameter.Parameter):
                 self.register_parameter(name, weight)
                 set_weight_attrs(weight, {"weight_loader": self.weight_loader})
 
@@ -521,13 +533,21 @@ class RowParallelLinear(torch.nn.Module):
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
         if input_dim is not None:
-            shard_size = param_data.shape[input_dim]
+            if loaded_weight.shape[input_dim] % tp_size != 0:
+                raise ValueError("Size is not aligned with the quantized "
+                                 "weight shape")
+
+            shard_size = loaded_weight.shape[input_dim] // tp_size
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
+        if isinstance(param, torch.nn.parameter.UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+            param_data = param.data
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 

+ 4 - 0
aphrodite/modeling/layers/quantization/__init__.py

@@ -2,12 +2,16 @@ from typing import Type
 
 from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
 from aphrodite.modeling.layers.quantization.awq import AWQConfig
+from aphrodite.modeling.layers.quantization.gguf import GGUFConfig
 from aphrodite.modeling.layers.quantization.gptq import GPTQConfig
+from aphrodite.modeling.layers.quantization.quip import QuipConfig
 from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
 
 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
+    "gguf": GGUFConfig,
     "gptq": GPTQConfig,
+    "quip": QuipConfig,
     "squeezellm": SqueezeLLMConfig,
 }
 

+ 6 - 0
aphrodite/modeling/layers/quantization/awq.py

@@ -66,6 +66,12 @@ class AWQConfig(QuantizationConfig):
     def get_scaled_act_names(self) -> List[str]:
         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
 
+    def merge_weight(self) -> bool:
+        return True
+
+    def rope_style(self) -> Optional[bool]:
+        return None
+
 
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.

+ 13 - 1
aphrodite/modeling/layers/quantization/base_config.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 
@@ -62,3 +62,15 @@ class QuantizationConfig(ABC):
         For now, this is only used by AWQ.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def merge_weight(self) -> bool:
+        """whether fuse qkv and up/gate."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def rope_style(self) -> Optional[bool]:
+        raise NotImplementedError
+
+    def quant_vocab(self) -> Optional[bool]:
+        return False

+ 150 - 0
aphrodite/modeling/layers/quantization/gguf.py

@@ -0,0 +1,150 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from aphrodite._C import ops
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              set_weight_attrs)
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
+
+GGML_QUANT_SIZES = {
+    0: (1, 4),
+    1: (1, 2),
+    2: (32, 2 + 16),
+    3: (32, 2 + 2 + 16),
+    6: (32, 2 + 4 + 16),
+    7: (32, 2 + 2 + 4 + 16),
+    8: (32, 2 + 32),
+    9: (32, 4 + 4 + 32),
+    10: (256, 2 + 2 + 256 // 16 + 256 // 4),
+    11: (256, 2 + 256 // 4 + 256 // 8 + 12),
+    12: (256, 2 + 2 + 256 // 2 + 12),
+    13: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
+    14: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
+    15: (256, 4 + 256 + 256 // 8),
+    16: (256, 2 + 256 // 4),
+    17: (256, 2 + 256 // 4 + 256 // 32),
+}
+
+
+class GGUFConfig(QuantizationConfig):
+    """Config class for GGUF"""
+
+    def __repr__(self) -> str:
+        return "GGUFConfig()"
+
+    def get_name(self) -> str:
+        return "gguf"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    def get_min_capability(self) -> int:
+        return 70
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
+        return cls()
+
+    def get_linear_method(self) -> "GGUFLinearMethod":
+        return GGUFLinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def merge_weight(self) -> bool:
+        return False
+
+    def rope_style(self) -> Optional[bool]:
+        return False
+
+    def quant_vocab(self) -> Optional[bool]:
+        return True
+
+
+class GGUFLinearMethod(LinearMethodBase):
+    """Linear method for GGUF.
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        # The type of weight is unknown until load state dict
+        weight = torch.nn.parameter.UninitializedParameter(requires_grad=False)
+        # No need for pack_factor because we don't fuse qkv layers anyway.
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+        })
+        weight_type = Parameter(
+            torch.tensor((1), dtype=torch.int, device="cuda"),
+            requires_grad=False,
+        )
+        set_weight_attrs(weight_type, {"ignore_warning": True})
+        return {"weight": weight, "weight_type": weight_type}
+
+    def apply_weights(self,
+                      weights: Dict[str, Any],
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if isinstance(weights["weight_type"], torch.Tensor):
+            weights["weight_type"] = int(weights["weight_type"])
+            # Check tensor parallel shape here on first pass
+            block_size = GGML_QUANT_SIZES[weights["weight_type"]][1]
+            if weights["weight"].shape[1] % block_size != 0:
+                raise ValueError("Size is not aligned with the quantized "
+                                 "weight shape.")
+
+        weight = weights["weight"]
+        weight_type = weights["weight_type"]
+        infeatures = x.shape[-1]
+        outfeatures = weight.shape[0]
+        out_shape = x.shape[:-1] + (weight.shape[0], )
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        xshape = x.view(-1, x.shape[-1])
+        if xshape.shape[0] == 1:
+            out = ops.ggml_mul_mat_vec_a8(weight, reshaped_x, weight_type,
+                                          outfeatures)
+        elif xshape.shape[0] < 8 and weight_type < 16:
+            out = ops.ggml_mul_mat_a8(weight, reshaped_x, weight_type,
+                                      outfeatures)
+        else:
+            weight = ops.ggml_dequantize(weight, weight_type, outfeatures,
+                                         infeatures)
+            out = reshaped_x @ weight.T
+
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
+
+    def apply_embedding(self, weights: Dict[str, torch.Tensor],
+                        x: torch.Tensor) -> torch.Tensor:
+        if isinstance(weights["weight_type"], torch.Tensor):
+            weights["weight_type"] = int(weights["weight_type"])
+        weight = weights["weight"]
+        weight_type = weights["weight_type"]
+        dim, block_size = GGML_QUANT_SIZES[weights["weight_type"]]
+        vocab_size = weight.shape[0]
+        hidden_size = weight.shape[1] // block_size * dim
+        if weight_type < 2:
+            return torch.embedding(weight.view(vocab_size, -1), x)
+        x_flat = x.flatten()
+        quant = torch.index_select(weight.view(vocab_size, -1),
+                                   dim=0,
+                                   index=x_flat)
+        dequant = ops.ggml_dequantize(quant, weight_type, hidden_size,
+                                      x_flat.shape[0])
+        return dequant.view(*x.shape, hidden_size)

+ 118 - 17
aphrodite/modeling/layers/quantization/gptq.py

@@ -3,31 +3,92 @@ from enum import Enum
 from typing import Any, Dict, List, Optional
 from fractions import Fraction
 
+import numpy as np
 import torch
 from torch.nn.parameter import Parameter
 
 from aphrodite._C import ops
+from aphrodite.common.utils import is_hip
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import (
     QuantizationConfig)
 
 
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                    2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4),
+                    2 * (i % 4 + 4) + 1
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm = perm.reshape((-1, 8))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend(
+            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
+def pemute_weight(qweight, scale, group_size, g_idx=None):
+    # unpack and permute qweight
+    w = torch.bitwise_right_shift(
+        torch.unsqueeze(qweight, 1).expand(-1, 8, -1),
+        torch.tensor(list(range(0, 32, 4)),
+                     dtype=torch.int32,
+                     device=qweight.device).unsqueeze(0).unsqueeze(-1),
+    ).bitwise_and(15)
+    w = w.reshape(-1, w.shape[2]).contiguous()
+    if g_idx is not None:
+        w = w[g_idx, :]
+    tile = 16
+    w = w.reshape((w.shape[0] // tile, tile, w.shape[1] // tile, tile))
+    w = w.permute((0, 2, 1, 3)).reshape(w.shape[0], -1)
+    res = w.reshape((-1, _perm.numel()))[:, _perm].reshape(w.shape)
+    q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
+    res = res.cpu().numpy().astype(np.uint32)
+    for i in range(8):
+        q |= res[:, i::8] << 4 * i
+    q = torch.from_numpy(q.astype(np.int32)).to(w.device)
+    # permute scale
+    dim = scale.shape[1]
+    if group_size == -1:
+        scale = scale.reshape(
+            (-1, len(_scale_perm_single)))[:, _scale_perm_single]
+    else:
+        scale = scale.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+    scale = scale.reshape((-1, dim)).contiguous()
+    return q, scale
+
+
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
 
     Reference: https://arxiv.org/abs/2210.17323
     """
 
-    def __init__(
-        self,
-        weight_bits: int,
-        group_size: int,
-        desc_act: bool,
-    ) -> None:
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 sym: bool) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
+        self.sym = sym
         self.pack_factor = Fraction(32, self.weight_bits)
         if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
@@ -37,7 +98,8 @@ class GPTQConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}, "
+                f"sym={self.sym}")
 
     @classmethod
     def get_name(cls) -> str:
@@ -61,7 +123,8 @@ class GPTQConfig(QuantizationConfig):
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
-        return cls(weight_bits, group_size, desc_act)
+        sym = cls.get_from_keys(config, ["sym"])
+        return cls(weight_bits, group_size, desc_act, sym)
 
     def get_linear_method(self) -> "GPTQLinearMethod":
         return GPTQLinearMethod(self)
@@ -69,12 +132,20 @@ class GPTQConfig(QuantizationConfig):
     def get_scaled_act_names(self) -> List[str]:
         return []
 
+    def merge_weight(self) -> bool:
+        return True
+
+    def rope_style(self) -> Optional[bool]:
+        return None
+
 
 class ExllamaState(Enum):
 
     UNUSED = enum.auto()
     UNINITIALIZED = enum.auto()
     READY = enum.auto()
+    MARLIN_UNINITIALIZED = enum.auto()
+    MARLIN_READY = enum.auto()
 
 
 class GPTQLinearMethod(LinearMethodBase):
@@ -86,6 +157,14 @@ class GPTQLinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: GPTQConfig):
         self.quant_config = quant_config
+        self.workspace = torch.zeros((512, ), dtype=torch.int, device="cuda")
+
+    def fit_marlin(self, output_size):
+        return self.quant_config.group_size in (-1, 128) and (
+            self.quant_config.weight_bits
+            == 4) and (self.quant_config.sym) and (
+                not self.quant_config.desc_act) and (output_size % 256
+                                                     == 0) and not is_hip()
 
     def create_weights(
         self,
@@ -115,15 +194,18 @@ class GPTQLinearMethod(LinearMethodBase):
         exllama_state = ExllamaState.UNINITIALIZED
         scale_and_zero_size = input_size // group_size
         scale_and_zero_input_dim = None
+        # For act-order models, we cannot use Exllama for row parallel layer
         if (input_size != input_size_per_partition
                 and self.quant_config.group_size != -1):
-            # For act-order models, we cannot use Exllama for row parallel layer
             if self.quant_config.desc_act:
                 exllama_state = ExllamaState.UNUSED
             else:
-                # we need to partition qzeros and scales for exllama kernel
                 scale_and_zero_size = input_size_per_partition // group_size
                 scale_and_zero_input_dim = 0
+                if self.fit_marlin(output_size_per_partition):
+                    exllama_state = ExllamaState.MARLIN_UNINITIALIZED
+        elif self.fit_marlin(output_size_per_partition):
+            exllama_state = ExllamaState.MARLIN_UNINITIALIZED
 
         qweight = Parameter(
             torch.empty(
@@ -195,8 +277,7 @@ class GPTQLinearMethod(LinearMethodBase):
                       weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        out_shape = x.shape[:-1] + (qweight.shape[-1], )
+        out_shape = x.shape[:-1] + (weights["scales"].shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
@@ -209,11 +290,31 @@ class GPTQLinearMethod(LinearMethodBase):
             weights["exllama_state"] = ExllamaState.READY
             ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
                              self.quant_config.weight_bits)
-        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
-                               weights["qzeros"], weights["scales"],
-                               weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY,
-                               self.quant_config.weight_bits)
+        elif weights["exllama_state"] == ExllamaState.MARLIN_UNINITIALIZED:
+            if self.quant_config.desc_act:
+                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
+                    torch.int)
+            else:
+                weights["g_idx"] = None
+            weights["qweight"], weights["scales"] = pemute_weight(
+                weights["qweight"], weights["scales"],
+                self.quant_config.group_size, weights["g_idx"])
+            weights["exllama_state"] = ExllamaState.MARLIN_READY
+
+        if weights["exllama_state"] == ExllamaState.MARLIN_READY:
+            output = torch.empty(out_shape, dtype=x.dtype, device=x.device)
+            # reorder input for act-order model
+            if weights["g_idx"] is not None:
+                reshaped_x = reshaped_x[:, weights["g_idx"]]
+            ops.marlin_gemm(reshaped_x, weights["qweight"],
+                            output.view(-1, output.shape[-1]),
+                            weights["scales"], self.workspace)
+        else:
+            output = ops.gptq_gemm(
+                reshaped_x, weights["qweight"], weights["qzeros"],
+                weights["scales"], weights["g_idx"],
+                weights["exllama_state"] == ExllamaState.READY,
+                self.quant_config.weight_bits)
         if bias is not None:
             output = output + bias
         return output.reshape(out_shape)

BIN
aphrodite/modeling/layers/quantization/hadamard.safetensors


+ 199 - 0
aphrodite/modeling/layers/quantization/quip.py

@@ -0,0 +1,199 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from aphrodite._C import ops
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              set_weight_attrs)
+from aphrodite.modeling.layers.quantization.base_config import (
+    QuantizationConfig)
+from aphrodite.modeling.layers.quantization.quip_utils import (
+    get_packed_abs_grid,
+    get_hadK,
+    matmul_hadUt_cuda,
+    matmul_hadU_cuda,
+)
+
+
+class QuipConfig(QuantizationConfig):
+    """Config class for Quip.
+    Reference: https://cornell-relaxml.github.io/quip-sharp/
+    """
+
+    def __init__(self, codebook: int, use_rand: bool) -> None:
+        self.codebook = codebook
+        self.use_rand = use_rand
+
+        if self.codebook != "E8P12":
+            raise ValueError("Currently, only E8P12 is supported for "
+                             f"Quip, but got {self.codebook}.")
+
+    def __repr__(self) -> str:
+        return (f"QuipConfig(codebook={self.codebook}, "
+                f"rescale_WH={self.rescale_WH})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "quip"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantization_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QuipConfig":
+        codebook = cls.get_from_keys(config, ["codebook"])
+        use_rand = cls.get_from_keys(config, ["use_rand"])
+        return cls(codebook, use_rand)
+
+    def get_linear_method(self) -> "QuipLinearMethod":
+        return QuipLinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def merge_weight(self) -> bool:
+        return False
+
+    def rope_style(self) -> Optional[bool]:
+        return None
+
+
+class QuipLinearMethod(LinearMethodBase):
+    """Linear method for Quip.
+    Args:
+        quant_config: The Quip quantization config.
+    """
+
+    def __init__(self, quant_config: QuipConfig):
+        self.quant_config = quant_config
+        self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
+        self.pack = 8
+        self.idx_dtype = torch.int16
+
+    def create_weights(
+        self,
+        input_size_per_partition: int,
+        output_size_per_partition: int,
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        if input_size != input_size_per_partition or output_size != output_size_per_partition:
+            raise ValueError(
+                "Currently Quip doesn't support tensor parallel yet")
+
+        had_left, K_left, q_in_features = get_hadK(input_size,
+                                                   self.quant_config.use_rand)
+        had_right, K_right, q_out_features = get_hadK(
+            output_size, self.quant_config.use_rand)
+        weights = {
+            "K_left": K_left,
+            "K_right": K_right,
+            "q_in_features": q_in_features,
+            "q_out_features": q_out_features,
+        }
+        if had_left is not None:
+            weights["had_left"] = Parameter(
+                had_left.to(dtype=params_dtype, device="cuda"),
+                requires_grad=False,
+            )
+            set_weight_attrs(weights["had_left"], {"ignore_warning": True})
+        if had_right is not None:
+            weights["had_right"] = Parameter(
+                had_right.to(dtype=params_dtype, device="cuda"),
+                requires_grad=False,
+            )
+            set_weight_attrs(weights["had_right"], {"ignore_warning": True})
+        Qidxs = Parameter(
+            torch.empty(q_out_features,
+                        q_in_features // self.pack,
+                        device="cuda",
+                        dtype=self.idx_dtype),
+            requires_grad=False,
+        )
+        set_weight_attrs(Qidxs, {"ignore_warning": True})
+        Wscale = Parameter(
+            torch.ones((), dtype=torch.float, device="cuda"),
+            requires_grad=False,
+        )
+        set_weight_attrs(Wscale, {"ignore_warning": True})
+        SU = Parameter(
+            torch.ones(
+                input_size,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(SU, {"ignore_warning": True})
+        SV = Parameter(
+            torch.ones(
+                output_size,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(SV, {"ignore_warning": True})
+        weights.update({
+            "Qidxs": Qidxs,
+            "Wscale": Wscale,
+            "SU": SU,
+            "SV": SV,
+        })
+        return weights
+
+    def apply_weights(self,
+                      weights: Dict[str, Any],
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # First run
+        if isinstance(weights["Wscale"], torch.Tensor):
+            weights["Wscale"] = weights["Wscale"].item()
+            if "SU" in weights and torch.all(weights["SU"] > 0):
+                del weights["SU"]
+            if "SV" in weights and torch.all(weights["SV"] > 0):
+                del weights["SV"]
+
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out_dim = weights["Qidxs"].shape[0]
+
+        if "SU" in weights:
+            reshaped_x = reshaped_x * weights["SU"]
+        reshaped_x = matmul_hadUt_cuda(reshaped_x,
+                                       weights.get("had_left",
+                                                   None), weights["K_left"],
+                                       weights["q_in_features"],
+                                       weights["Wscale"])
+
+        m, n = weights["Qidxs"].shape
+        if reshaped_x.size(0) < 32:
+            out = ops.quip_gemv(reshaped_x, weights["Qidxs"],
+                                self.grid_packed_abs)
+        else:
+            W_decompressed = torch.empty(m,
+                                         n * 8,
+                                         dtype=torch.float16,
+                                         device=x.device)
+            ops.quip_decompress(weights["Qidxs"], self.grid_packed_abs,
+                                W_decompressed)
+            out = reshaped_x @ W_decompressed.T
+
+        out = matmul_hadU_cuda(out, weights.get("had_right",
+                                                None), weights["K_right"],
+                               weights["q_out_features"])[..., :out_dim]
+        if "SV" in weights:
+            out = out * weights["SV"]
+        out = out.view(*x.shape[:-1], out.shape[-1])
+        out = out + bias if bias is not None else out
+        return out

+ 126 - 0
aphrodite/modeling/layers/quantization/quip_utils.py

@@ -0,0 +1,126 @@
+import math
+from pathlib import Path
+
+import scipy
+import torch
+import fast_hadamard_transform
+from safetensors.torch import load_file
+
+HADA_TENSORS = load_file(
+    Path(__file__).resolve().parent / "hadamard.safetensors")
+
+
+def int2mask(i, int_map):
+    return ((i & int_map) > 0).int()
+
+
+def mask2int(mask, int_map):
+    return (int_map.unsqueeze(0) * mask.int()).sum(dim=-1)
+
+
+def get_norm12():
+    # 29 elements of norm 12 in E8 + 1/4
+    return torch.tensor([
+        [3, 1, 1, 1, 3, 3, 3, 3],
+        [1, 3, 1, 1, 3, 3, 3, 3],
+        [1, 1, 3, 1, 3, 3, 3, 3],
+        [1, 1, 1, 3, 3, 3, 3, 3],
+        [3, 3, 3, 1, 3, 3, 1, 1],
+        [3, 3, 3, 1, 3, 1, 3, 1],
+        [3, 3, 3, 1, 1, 3, 3, 1],
+        [3, 3, 3, 1, 3, 1, 1, 3],
+        [3, 3, 3, 1, 1, 3, 1, 3],
+        [3, 3, 3, 1, 1, 1, 3, 3],
+        [3, 3, 1, 3, 3, 3, 1, 1],
+        [3, 3, 1, 3, 3, 1, 3, 1],
+        [3, 3, 1, 3, 1, 3, 3, 1],
+        [3, 3, 1, 3, 3, 1, 1, 3],
+        [3, 3, 1, 3, 1, 3, 1, 3],
+        [3, 3, 1, 3, 1, 1, 3, 3],
+        [3, 1, 3, 3, 3, 3, 1, 1],
+        [3, 1, 3, 3, 3, 1, 3, 1],
+        [3, 1, 3, 3, 1, 3, 3, 1],
+        [3, 1, 3, 3, 3, 1, 1, 3],
+        [3, 1, 3, 3, 1, 3, 1, 3],
+        [1, 3, 3, 3, 1, 1, 3, 3],
+        [1, 3, 3, 3, 3, 3, 1, 1],
+        [1, 3, 3, 3, 3, 1, 3, 1],
+        [1, 3, 3, 3, 1, 3, 3, 1],
+        [1, 3, 3, 3, 3, 1, 1, 3],
+        [1, 3, 3, 3, 1, 3, 1, 3],
+        [1, 1, 3, 3, 1, 3, 3, 3],
+        [3, 3, 1, 1, 3, 3, 3, 1],
+    ]) / 2
+
+
+def get_packed_abs_grid():
+    intr = torch.arange(-4, 4)
+    d8 = torch.cartesian_prod(*[intr] * 8).float() + 1 / 2
+    d8m2 = d8.sum(dim=-1) % 2 == 0
+    d8n = d8.norm(dim=-1)**2 <= 10
+    d8abs = torch.unique(d8[sorted(torch.where(d8m2 * d8n)[0])].abs(), dim=0)
+    norm12 = get_norm12()
+    cba = torch.concat([d8abs, norm12], dim=0)
+    cba = cba[:, [0, 2, 1, 3, 4, 6, 5, 7]]
+    cba[:, 7] *= (1 - 2 * (cba.sum(1) % 2))
+    cba = cba * 4
+    cba = cba.to(torch.int64)
+    acc = cba[:, 0]
+    for i in range(7):
+        acc = acc | (cba[:, (i + 1)] << ((i + 1) * 8))
+    return acc
+
+
+def next_power_of_2(n):
+    if n == 0:
+        return 1
+    return 2**math.ceil(math.log(n, 2))
+
+
+def get_power_of_2(n):
+    """Returns the highest power of 2 that divides n."""
+    k = 0
+    while n % 2 == 0:
+        n //= 2
+        k += 1
+    return k, n
+
+
+def get_hadK(n, use_rand=True):
+    exp, base = get_power_of_2(n)
+    if base == 1:
+        return None, 1, n
+    if use_rand:
+        rand_mat = torch.tensor(scipy.stats.special_ortho_group.rvs(base)).to(
+            torch.float32)
+        return rand_mat, base, n
+
+    # Use hadamad only and add padding if cannot find one
+    pad_n = next_power_of_2(n)
+    if exp < 2 or str(base * 4) not in HADA_TENSORS:
+        return None, 1, pad_n
+    base_mat = HADA_TENSORS[str(base * 4)] / math.sqrt(base * 4)
+    return base_mat, base * 4, n
+
+
+def matmul_hadU_cuda(X, hadK, K, n, scale=None, transpose=False):
+    if n != X.shape[-1]:
+        X = torch.nn.functional.pad(X, (0, n - X.shape[-1]))
+
+    had_scale = 1 / math.sqrt(n // K) if scale is None else scale / math.sqrt(
+        n // K)
+    if K == 1:
+        return fast_hadamard_transform.hadamard_transform(X.contiguous(),
+                                                          scale=had_scale)
+
+    if transpose:
+        hadK = hadK.T.contiguous()
+    input = X.view(-1, K, n // K)  # pylint: disable=redefined-builtin
+    input = fast_hadamard_transform.hadamard_transform(input.contiguous(),
+                                                       scale=had_scale)
+    input = hadK @ input
+    return input.reshape(X.shape)
+
+
+def matmul_hadUt_cuda(X, hadK, K, n, scale=None):
+    return matmul_hadU_cuda(X, hadK, K, n, scale=scale, transpose=True)

+ 6 - 0
aphrodite/modeling/layers/quantization/squeezellm.py

@@ -56,6 +56,12 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_scaled_act_names(self) -> List[str]:
         return []
 
+    def merge_weight(self) -> bool:
+        return True
+
+    def rope_style(self) -> Optional[bool]:
+        return None
+
 
 class SqueezeLLMLinearMethod(LinearMethodBase):
     """Linear method for SqueezeLLM.

+ 6 - 7
aphrodite/modeling/layers/sampler.py

@@ -52,16 +52,15 @@ class Sampler(nn.Module):
 
     def forward(
         self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+        logits = _prune_hidden_states(logits, sampling_metadata)
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.vocab_size]
 
         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because

+ 35 - 20
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -1,9 +1,9 @@
 from typing import Optional, Sequence
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -54,6 +54,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                  num_embeddings: int,
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
+                 linear_method=None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__()
@@ -74,22 +75,32 @@ class VocabParallelEmbedding(torch.nn.Module):
                 self.tp_size))
         self.num_embeddings_per_partition = (self.vocab_end_index -
                                              self.vocab_start_index)
-        self.weight = Parameter(
-            torch.empty(self.num_embeddings_per_partition,
-                        self.embedding_dim,
-                        device=torch.cuda.current_device(),
-                        dtype=params_dtype))
-        set_weight_attrs(self.weight, {
-            "parallel_dim": 0,
-            "weight_loader": self.weight_loader
-        })
+        if linear_method is None or not linear_method.quant_config.quant_vocab(
+        ):
+            linear_method = UnquantizedLinearMethod()
+        self.linear_method = linear_method
+        self.linear_weights = self.linear_method.create_weights(
+            self.embedding_dim, self.num_embeddings_per_partition,
+            self.embedding_dim, self.num_embeddings_padded, params_dtype)
+        for name, weight in self.linear_weights.items():
+            if isinstance(weight, torch.nn.parameter.Parameter):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
-        loaded_weight = loaded_weight[self.vocab_start_index:self.
-                                      vocab_end_index]
-        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
+        output_dim = getattr(param, "output_dim", None)
+        if output_dim is not None:
+            assert loaded_weight.shape[output_dim] == self.num_embeddings
+            loaded_weight = loaded_weight[self.vocab_start_index:self.
+                                          vocab_end_index]
+        if isinstance(param, torch.nn.parameter.UninitializedParameter):
+            param.materialize(
+                (self.num_embeddings_per_partition, loaded_weight.shape[1]),
+                dtype=loaded_weight.dtype)
+        if output_dim is not None:
+            param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
+        else:
+            param.data.copy_(loaded_weight)
 
     def forward(self, input_):
         if self.tp_size > 1:
@@ -102,7 +113,8 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
             # Get the embeddings.
-        output_parallel = F.embedding(masked_input, self.weight)
+        output_parallel = self.linear_method.apply_embedding(
+            self.linear_weights, masked_input)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel[input_mask, :] = 0.0
@@ -132,22 +144,25 @@ class ParallelLMHead(VocabParallelEmbedding):
                  embedding_dim: int,
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
+                 linear_method=None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size)
+                         linear_method, org_num_embeddings, padding_size)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
                             device=torch.cuda.current_device(),
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
-                "parallel_dim": 0,
+                "output_dim": 0,
                 "weight_loader": self.weight_loader
             })
         else:
             self.register_parameter("bias", None)
 
     def forward(self, input_):
-        del input_
-        raise RuntimeError("LMHead's weights should be used in the sampler.")
+        logits = self.linear_method.apply_weights(self.linear_weights, input_)
+        if self.bias is not None:
+            logits += self.bias
+        return logits

+ 37 - 10
aphrodite/modeling/models/gpt_j.py

@@ -56,13 +56,30 @@ class GPTJAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
 
-        self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(config.hidden_size,
+                                               config.hidden_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(config.hidden_size,
+                                               config.hidden_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(config.hidden_size,
+                                               config.hidden_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                config.hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.out_proj = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
@@ -96,8 +113,13 @@ class GPTJAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.chunk(chunks=3, dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -183,6 +205,7 @@ class GPTJModel(nn.Module):
         self.wte = VocabParallelEmbedding(
             config.vocab_size,
             self.embed_dim,
+            linear_method=linear_method,
         )
         self.h = nn.ModuleList(
             [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
@@ -224,6 +247,7 @@ class GPTJForCausalLM(nn.Module):
             config.vocab_size,
             config.n_embd,
             bias=True,
+            linear_method=linear_method,
         )
         self.sampler = Sampler(config.vocab_size)
 
@@ -243,7 +267,7 @@ class GPTJForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata, self.lm_head.bias)
         return next_tokens
 
@@ -260,6 +284,9 @@ class GPTJForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):

+ 6 - 1
aphrodite/modeling/models/gpt_neox.py

@@ -82,11 +82,14 @@ class GPTNeoXAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_size,
             rotary_dim=rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
+            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
 
@@ -196,6 +199,7 @@ class GPTNeoXModel(nn.Module):
         self.embed_in = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
         )
         self.layers = nn.ModuleList([
             GPTNeoXLayer(config, linear_method)
@@ -238,6 +242,7 @@ class GPTNeoXForCausalLM(nn.Module):
         self.embed_out = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
         )
         self.sampler = Sampler(config.vocab_size)
 
@@ -257,7 +262,7 @@ class GPTNeoXForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
+        next_tokens = self.sampler(self.embed_out(hidden_states),
                                    sampling_metadata)
         return next_tokens
 

+ 67 - 17
aphrodite/modeling/models/llama.py

@@ -35,7 +35,8 @@ from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
-                                              RowParallelLinear)
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -61,10 +62,23 @@ class LlamaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            linear_method=linear_method)
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(hidden_size,
+                                                  intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(hidden_size,
+                                                intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
@@ -75,7 +89,12 @@ class LlamaMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -116,14 +135,31 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(hidden_size,
+                                               self.q_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -131,12 +167,15 @@ class LlamaAttention(nn.Module):
             linear_method=linear_method,
         )
 
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -150,8 +189,14 @@ class LlamaAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -239,6 +284,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
@@ -288,6 +334,7 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(
             unpadded_vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
@@ -312,7 +359,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata)
         return next_tokens
 
@@ -329,6 +376,9 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):

+ 67 - 17
aphrodite/modeling/models/mistral.py

@@ -35,7 +35,8 @@ from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
-                                              RowParallelLinear)
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -61,10 +62,23 @@ class MistralMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            linear_method=linear_method)
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(hidden_size,
+                                                  intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(hidden_size,
+                                                intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
@@ -75,7 +89,12 @@ class MistralMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -114,14 +133,31 @@ class MistralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(hidden_size,
+                                               self.q_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -129,11 +165,14 @@ class MistralAttention(nn.Module):
             linear_method=linear_method,
         )
 
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=self.rope_theta,
+            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -148,8 +187,14 @@ class MistralAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -235,6 +280,7 @@ class MistralModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
@@ -286,6 +332,7 @@ class MistralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(
             unpadded_vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
@@ -310,7 +357,7 @@ class MistralForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata)
         return next_tokens
 
@@ -327,6 +374,9 @@ class MistralForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):

+ 46 - 15
aphrodite/modeling/models/mixtral.py

@@ -38,7 +38,8 @@ from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               ReplicatedLinear,
                                               QKVParallelLinear,
-                                              RowParallelLinear)
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -194,26 +195,45 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(hidden_size,
+                                               self.q_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             linear_method=linear_method,
         )
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=int(self.rope_theta),
-            is_neox_style=True,
+            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(
             self.num_heads,
@@ -230,8 +250,14 @@ class MixtralAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -308,6 +334,7 @@ class MixtralModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            linear_method=linear_method,
         )
         self.layers = nn.ModuleList([
             MixtralDecoderLayer(config, linear_method=linear_method)
@@ -344,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = MixtralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -363,7 +392,7 @@ class MixtralForCausalLM(nn.Module):
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata)
         return next_tokens
 
@@ -378,7 +407,9 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,

+ 336 - 0
aphrodite/modeling/models/phi.py

@@ -0,0 +1,336 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Inference-only Phi model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding, ParallelLMHead)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class PhiAttention(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.total_num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.total_num_heads
+
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+
+        # pylint: disable=C0103
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.hidden_size,
+                                               bias=True,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.hidden_size,
+                                               bias=True,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.hidden_size,
+                                               bias=True,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                self.hidden_size,
+                self.head_size,
+                self.total_num_heads,
+                bias=True,
+                linear_method=linear_method,
+            )
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            linear_method=linear_method,
+        )
+
+        scaling = self.head_size**-0.5
+        rotary_dim = int(config.partial_rotary_factor *
+                         (config.hidden_size // config.num_attention_heads))
+        assert rotary_dim % 2 == 0
+
+        # pylint: disable=C0301
+        # Refer to:
+        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
+        rope_theta = 10000
+        max_position_embeddings = getattr(config, "n_positions", 2048)
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            is_neox_style=is_neox_style,
+        )
+        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.chunk(chunks=3, dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
+        q, k = self.rotary_emb(position_ids, q, k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class PhiMLP(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+
+        n_inner = getattr(config, "n_inner", None)
+        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
+
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            n_inner,
+            linear_method=linear_method,
+        )
+        self.fc2 = RowParallelLinear(
+            n_inner,
+            config.hidden_size,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
+
+    def forward(self, hidden_states):
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+        return hidden_states
+
+
+class PhiLayer(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.self_attn = PhiAttention(config, linear_method)
+        self.mlp = PhiMLP(config, linear_method)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_outputs = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+        )
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual
+        return hidden_states
+
+
+class PhiModel(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
+        self.layers = nn.ModuleList([
+            PhiLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        for i in range(self.config.num_hidden_layers):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+            )
+
+        hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class PhiForCausalLM(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+
+        self.model = PhiModel(config, linear_method)
+
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=True,
+                                      linear_method=linear_method)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata)
+
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        head = self.lm_head  # pylint: disable=unused-variable
+        next_tokens = self.sampler(self.lm_head(hidden_states),
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v")
+        ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision):
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # pylint: disable=E1136
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

+ 71 - 22
aphrodite/modeling/models/yi.py

@@ -35,7 +35,8 @@ from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
-                                              RowParallelLinear)
+                                              RowParallelLinear,
+                                              ColumnParallelLinear)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -60,10 +61,23 @@ class YiMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            linear_method=linear_method)
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(hidden_size,
+                                                  intermediate_size,
+                                                  bias=False,
+                                                  linear_method=linear_method)
+            self.up_proj = ColumnParallelLinear(hidden_size,
+                                                intermediate_size,
+                                                bias=False,
+                                                linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
@@ -74,7 +88,12 @@ class YiMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -115,26 +134,46 @@ class YiAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
+        if linear_method is not None and not linear_method.quant_config.merge_weight(
+        ):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(hidden_size,
+                                               self.q_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.k_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+            self.v_proj = ColumnParallelLinear(hidden_size,
+                                               self.kv_size,
+                                               bias=False,
+                                               linear_method=linear_method)
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             linear_method=linear_method,
         )
+        is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
+        ) is None else linear_method.quant_config.rope_style()
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=self.rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
         )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
@@ -148,8 +187,14 @@ class YiAttention(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
@@ -226,10 +271,9 @@ class YiModel(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
         self.layers = nn.ModuleList([
             YiDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -269,7 +313,9 @@ class YiForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = YiModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      linear_method=linear_method)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -288,7 +334,7 @@ class YiForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head(hidden_states),
                                    sampling_metadata)
         return next_tokens
 
@@ -305,6 +351,9 @@ class YiForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
+        ):
+            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):

+ 2 - 0
aphrodite/task_handler/model_runner.py

@@ -670,6 +670,8 @@ class ModelRunner:
         # memory usage of CUDA graph.
         with custom_all_reduce.capture():
             for batch_size in reversed(batch_size_capture_list):
+                if batch_size > self.scheduler_config.max_num_seqs:
+                    continue
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
                     is_prompt=False,

+ 151 - 0
examples/gguf_to_torch.py

@@ -0,0 +1,151 @@
+import json
+import os
+
+import torch
+import gguf
+from sentencepiece import sentencepiece_model_pb2
+
+def convert_to_state_dict(checkpoint, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    state_dict = {}
+    result = gguf.GGUFReader(checkpoint)
+    architecture = result.fields['general.architecture']
+    architecture = str(bytes(architecture.parts[architecture.data[0]]), encoding = 'utf-8')
+    if architecture != "llama":
+        print(f"Unsupported architecture {architecture}")
+        return
+
+    # write vocab
+    vocab = sentencepiece_model_pb2.ModelProto()
+    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
+    vocab.trainer_spec.model_type = 2 # BPE
+    vocab.trainer_spec.vocab_size = vocab_size
+    vocab.trainer_spec.byte_fallback = True
+    vocab.normalizer_spec.remove_extra_whitespaces = False
+    tokens = result.fields['tokenizer.ggml.tokens']
+    scores = result.fields['tokenizer.ggml.scores']
+    types = result.fields['tokenizer.ggml.token_type']
+    for i in range(vocab_size):
+        new_token = vocab.SentencePiece()
+        new_token.piece = str(bytes(tokens.parts[tokens.data[i]]), encoding = 'utf-8')
+        new_token.score = scores.parts[scores.data[i]]
+        # llama.cpp tokentype is the same with sentencepiece token type
+        new_token.type = int(types.parts[types.data[i]])
+        vocab.pieces.append(new_token)
+    with open(os.path.join(save_dir, "tokenizer.model"), 'wb') as f:
+        f.write(vocab.SerializeToString())
+    tokenizer_config = {
+        "tokenizer_class": "LlamaTokenizer",
+        "legacy": False,
+        "clean_up_tokenization_spaces": False,
+    }
+    if 'tokenizer.ggml.bos_token_id' in result.fields:
+        tokenizer_config["bos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.eos_token_id' in result.fields:
+        tokenizer_config["eos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.padding_token_id' in result.fields:
+        tokenizer_config["pad_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.unknown_token_id' in result.fields:
+        tokenizer_config["unk_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
+    if 'tokenizer.ggml.add_bos_token' in result.fields:
+        tokenizer_config["add_bos_token"] = bool(result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
+    if 'tokenizer.ggml.add_eos_token' in result.fields:
+        tokenizer_config["add_eos_token"] = bool(result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
+    if 'tokenizer.chat_template' in result.fields:
+        tokenizer_config["chat_template"] = str(bytes(result.fields['tokenizer.chat_template'].parts[-1]), encoding="utf-8")
+    json.dump(tokenizer_config, open(os.path.join(save_dir, "tokenizer_config.json"), 'w'), indent=2)
+
+    # write config
+    context_length = int(result.fields['llama.context_length'].parts[-1])
+    n_layer = int(result.fields['llama.block_count'].parts[-1])
+    n_head = int(result.fields['llama.attention.head_count'].parts[-1])
+    n_local_heads = int(result.fields['llama.attention.head_count_kv'].parts[-1])
+    intermediate_size = int(result.fields['llama.feed_forward_length'].parts[-1])
+    norm_eps = float(result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
+    dim = int(result.fields['llama.embedding_length'].parts[-1])
+    kv_dim = dim // n_head * n_local_heads
+    arch = "MixtralForCausalLM"
+    if 'llama.expert_count' in result.fields:
+        arch = "MixtralForCausalLM"
+        name = "mixtral"
+    else:
+        arch = "LlamaForCausalLM"
+        name = "llama"
+    model_config= {
+        "architectures": [arch],
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "hidden_act": "silu",
+        "hidden_size": dim,
+        "intermediate_size": intermediate_size,
+        "max_position_embeddings": context_length,
+        "model_type": name,
+        "num_attention_heads": n_head,
+        "num_hidden_layers": n_layer,
+        "num_key_value_heads": n_local_heads,
+        "rms_norm_eps": norm_eps,
+        "torch_dtype": "float16",
+        "vocab_size": vocab_size
+    }
+    if 'llama.rope.freq_base' in result.fields:
+        model_config['rope_theta'] = float(result.fields['llama.rope.freq_base'].parts[-1])
+    if 'llama.expert_count' in result.fields:
+        model_config['num_local_experts'] = int(result.fields['llama.expert_count'].parts[-1])
+        model_config['num_experts_per_tok'] = int(result.fields['llama.expert_used_count'].parts[-1])
+    json.dump(model_config, open(os.path.join(save_dir, "config.json"), 'w'), indent=2)
+
+    # write tensor
+    tensor_mapping = {
+        "token_embd": ("model.embed_tokens", vocab_size),
+        "output": ("lm_head", vocab_size),
+        "output_norm": ("model.norm", -1),
+        "blk.{bid}.attn_norm": ("model.layers.{bid}.input_layernorm", -1),
+        "blk.{bid}.attn_q": ("model.layers.{bid}.self_attn.q_proj", dim),
+        "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
+        "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
+        "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj", dim),
+        "blk.{bid}.attn_rot_embd": ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
+        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm", -1),
+        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj", intermediate_size),
+        "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj", dim),
+        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj", intermediate_size),
+        "blk.{bid}.ffn_up.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", intermediate_size),
+        "blk.{bid}.ffn_down.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", dim),
+        "blk.{bid}.ffn_gate.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", intermediate_size),
+        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate", model_config.get('num_local_experts', 1)),
+    }
+    mapping = {}
+    max_block_num = 200
+    max_expert_num = 8
+    for k, v in tensor_mapping.items():
+        for i in range(max_block_num):
+            for j in range(max_expert_num):
+                fk = k.format(bid=i, xid=j)
+                fv = v[0].format(bid=i, xid=j)
+                if k not in mapping:
+                    mapping[fk] = (fv, v[1])
+
+    for ts in result.tensors:
+        weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
+        layer, suffix = ts.name.rsplit(".", 1)
+        new_key, output_dim = mapping[layer]
+        new_key += f".{suffix}"
+        data = torch.tensor(ts.data)
+        if output_dim != -1:
+            data = data.view(output_dim, -1)
+        if weight_type > 1:
+            state_dict[new_key.replace("weight", "weight_type")] = weight_type
+        state_dict[new_key] = data
+    torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))
+
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description='Convert GGUF checkpoints to torch')
+
+    parser.add_argument('--input', type=str, help='The path to GGUF file')
+    parser.add_argument('--output', type=str, help='The path to output directory')
+    args = parser.parse_args()
+    convert_to_state_dict(args.input, args.output)

+ 46 - 0
kernels/ops.h

@@ -80,6 +80,24 @@ torch::Tensor awq_dequantize(
     int split_k_iters,
     int thx,
     int thy);
+
+void marlin_gemm(
+  const torch::Tensor& input,
+  const torch::Tensor& weights,
+        torch::Tensor& output,
+  const torch::Tensor& scales,
+        torch::Tensor& workspace);
+
+at::Tensor e8p_mm_origorder(
+    const at::Tensor& A,
+    const at::Tensor& B,
+    const at::Tensor& CB);
+
+void decompress_e8p_origorder(
+    torch::Tensor YIs,
+    torch::Tensor CB,
+    torch::Tensor &Y
+);
 #endif
 
 void squeezellm_gemm(
@@ -102,6 +120,34 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
+torch::Tensor ggml_dequantize(
+    torch::Tensor X,
+    int8_t type,
+    int64_t m,
+    int64_t n
+);
+
+torch::Tensor ggml_mul_mat_vec(
+    torch::Tensor W,  // quant weight
+    torch::Tensor X,  // input
+    int8_t type,
+    int64_t m
+);
+
+torch::Tensor ggml_mul_mat_vec_a8(
+    torch::Tensor W,  // quant weight
+    torch::Tensor X,  // input
+    int8_t type,
+    int64_t row
+);
+
+torch::Tensor ggml_mul_mat_a8(
+    torch::Tensor W,  // quant weight
+    torch::Tensor X,  // input
+    int8_t type,
+    int64_t row
+);
+
 #ifndef USE_ROCM
 using fptr_t = uint64_t;
 fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,

+ 7 - 0
kernels/pybind.cpp

@@ -52,10 +52,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
+  ops.def("quip_decompress", &decompress_e8p_origorder, "decompress_packed_e8p");
+  ops.def("quip_gemv", &e8p_mm_origorder, "e8p_mm_origorder");
+  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
 #endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+  ops.def("ggml_dequantize", &ggml_dequantize, "ggml_dequantize");
+  ops.def("ggml_mul_mat_vec", &ggml_mul_mat_vec, "ggml_mul_mat_vec");
+  ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8, "ggml_mul_mat_vec_a8");
+  ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8, "ggml_mul_mat_a8");
 
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "Aphrodite cache ops");

+ 3925 - 0
kernels/quantization/gguf/gguf_kernel.cu

@@ -0,0 +1,3925 @@
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+
+#define QK_K 256
+#define K_QUANTS_PER_ITERATION 2
+#define WARP_SIZE 32
+#define K_SCALE_SIZE 12
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define GGML_CUDA_DMMV_X 32
+#define GGML_CUDA_MMV_Y 1
+
+
+// Data Structures
+// QK = number of values after dequantization
+// QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
+
+#define QK4_0 32
+#define QR4_0 2
+#define QI4_0 (QK4_0 / (4 * QR4_0))
+typedef struct {
+    half    d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+
+#define QK4_1 32
+#define QR4_1 2
+#define QI4_1 (QK4_1 / (4 * QR4_1))
+typedef struct {
+    half2   dm;             // dm.x = delta, dm.y = min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+
+#define QK5_0 32
+#define QR5_0 2
+#define QI5_0 (QK5_0 / (4 * QR5_0))
+typedef struct {
+    half d;                 // delta
+    uint8_t qh[4];          // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2];  // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+#define QR5_1 2
+#define QI5_1 (QK5_1 / (4 * QR5_1))
+typedef struct {
+    half2 dm;               // dm.x = delta, dm.y = min
+    uint8_t qh[4];          // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2];  // nibbles / quants
+} block_q5_1;
+
+#define QK8_0 32
+#define QR8_0 1
+#define QI8_0 (QK8_0 / (4 * QR8_0))
+typedef struct {
+    half    d;              // delta
+    int8_t  qs[QK8_0];      // quants
+} block_q8_0;
+
+#define QK8_1 32
+#define QR8_1 1
+#define QI8_1 (QK8_1 / (4 * QR8_1))
+typedef struct {
+    half2   ds;             // ds.x = delta, ds.y = sum
+    int8_t  qs[QK8_0];      // quants
+} block_q8_1;
+
+#define QR2_K 4
+#define QI2_K (QK_K / (4*QR2_K))
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half2 dm;                // super-block scale for quantized scales/mins
+} block_q2_K;
+
+#define QR3_K 4
+#define QI3_K (QK_K / (4*QR3_K))
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+    half d;             // super-block scale
+} block_q3_K;
+
+#define QR4_K 2
+#define QI4_K (QK_K / (4*QR4_K))
+typedef struct {
+    half2 dm;                  // super-block scale for quantized scales/mins
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+
+#define QR5_K 2
+#define QI5_K (QK_K / (4*QR5_K))
+typedef struct {
+    half2 dm;                     // super-block scale for quantized scales/mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];           // quants, high bit
+    uint8_t qs[QK_K/2];           // quants, low 4 bits
+} block_q5_K;
+
+#define QR6_K 2
+#define QI6_K (QK_K / (4*QR6_K))
+typedef struct {
+    uint8_t ql[QK_K/2];   // quants, lower 4 bits
+    uint8_t qh[QK_K/4];   // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales
+    half    d;         // delta
+} block_q6_K;
+
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+
+#define QR2_XS 8
+#define QI2_XS (QK_K / (4*QR2_XS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+    uint8_t  scales[QK_K/32];
+} block_iq2_xs;
+
+static const __device__ uint64_t iq2xxs_grid[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const __device__ uint64_t iq2xs_grid[512] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+};
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
+typedef void (*load_tiles_cuda_t)(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_cuda_t)(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
+
+// Utility function
+
+#if defined(USE_ROCM)
+
+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int &>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if __has_builtin(__builtin_amdgcn_sdot4)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif // defined(USE_ROCM)
+
+static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    int x32 = 0;
+    x32 |= x16[0] <<  0;
+    x32 |= x16[1] << 16;
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    int x32 = 0;
+    x32 |= x16[0] <<  0;
+    x32 |= x16[1] << 16;
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+// Dequant functions
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = __int2half_rn(vui & 0xF);
+    v.y = __int2half_rn(vui >> 4);
+
+    v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
+    v = __hmul2(v, {d, d});
+}
+
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = __int2half_rn(vui & 0xF);
+    v.y = __int2half_rn(vui >> 4);
+
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+}
+
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = __int2half_rn((x[ib].qs[iqs] >>  4) | xh_1);
+
+    v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
+    v = __hmul2(v, {d, d});
+}
+
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = __int2half_rn((x[ib].qs[iqs] >>  4) | xh_1);
+
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+}
+
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    v.x = __int2half_rn(x[ib].qs[iqs + 0]);
+    v.y = __int2half_rn(x[ib].qs[iqs + 1]);
+
+    v = __hmul2(v, {d, d});
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (i >= k) {
+        return;
+    }
+
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0]        = v.x;
+    y[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
+    const int tid = threadIdx.x;
+    const int n   = tid/32;
+    const int l   = tid - 32*n;
+    const int is  = 8*n + l/16;
+
+    const uint8_t q = x[i].qs[32*n + l];
+    dst_t * y = yy + i*QK_K + 128*n;
+
+    half dall = __low2half(x[i].dm);
+    half dmin = __high2half(x[i].dm);
+    y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin,  __int2half_rn(x[i].scales[is+0] >> 4)));
+    y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin,  __int2half_rn(x[i].scales[is+2] >> 4)));
+    y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin,  __int2half_rn(x[i].scales[is+4] >> 4)));
+    y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin,  __int2half_rn(x[i].scales[is+6] >> 4)));
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    half d_all = x[i].d;
+    half dl = __hmul(d_all,  __int2half_rn(us - 32));
+
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl,  __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
+}
+
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    const int i = blockIdx.x;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int is  = 2*il;
+    const int n   = 4;
+
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+    const half dall = __low2half(x[i].dm);
+    const half dmin = __high2half(x[i].dm);
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const half d1 = __hmul(dall, __int2half_rn(sc));
+    const half m1 = __hmul(dmin,  __int2half_rn(m));
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const half d2 = __hmul(dall, __int2half_rn(sc));
+    const half m2 = __hmul(dmin, __int2half_rn(m));
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
+        y[l +32] = __hsub(__hmul(d2,  __int2half_rn(q[l] >> 4)), m2);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q5_K * x = (const block_q5_K *) vx;
+
+    const int i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = threadIdx.x;
+    const int il  = tid/16;   // il is in 0...3
+    const int ir  = tid%16;   // ir is in 0...15
+    const int is  = 2*il;     // is is in 0...6
+
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+    const half dall = __low2half(x[i].dm);
+    const half dmin = __high2half(x[i].dm);
+
+    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+    const uint8_t * qh = x[i].qh + 2*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const half d1 = __hmul(dall, __int2half_rn(sc)); const half m1 = __hmul(dmin, __int2half_rn(m));
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
+    y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
+    hm <<= 1;
+    y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >>  4) + (qh[0] & hm ? 16 : 0))), m2);
+    y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >>  4) + (qh[1] & hm ? 16 : 0))), m2);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q6_K * x = (const block_q6_K *) vx;
+
+    const int i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = threadIdx.x;
+    const int ip  = tid/32;   // ip is 0 or 1
+    const int il  = tid - 32*ip; // 0...32
+    const int is  = 8*ip + il/16;
+
+    dst_t * y = yy + i*QK_K + 128*ip + il;
+
+    const half d = x[i].d;
+
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t   qh = x[i].qh[32*ip + il];
+    const int8_t  * sc = x[i].scales + is;
+
+    y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
+    y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
+    y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
+    y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
+
+    const int tid = threadIdx.x;
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+    const int tid = threadIdx.x;
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+    const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
+    switch (type) {
+        case 2:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case 3:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case 6:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case 7:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case 8:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case 10:
+            return dequantize_row_q2_K_cuda;
+        case 11:
+            return dequantize_row_q3_K_cuda;
+        case 12:
+            return dequantize_row_q4_K_cuda;
+        case 13:
+            return dequantize_row_q5_K_cuda;
+        case 14:
+            return dequantize_row_q6_K_cuda;
+        case 16:
+            return dequantize_row_iq2_xxs_cuda;
+        case 17:
+            return dequantize_row_iq2_xs_cuda;
+        default:
+            return nullptr;
+    }
+}
+
+// GEMV
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, dfloat * __restrict__ dst, const int ncols, const int nrows) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = threadIdx.x;
+
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    half2 tmp = __floats2half2_rn(0.0f, 0.0f); // two sums for f16 to take advantage of half2 intrinsics
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+            tmp = __hadd2(tmp, __hmul2(v, {
+                y[iybs + iqs + j/qr + 0],
+                y[iybs + iqs + j/qr + y_offset]
+            }));
+        }
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp = __hadd2(tmp, __shfl_xor_sync(0xffffffff, tmp, mask, 32));
+    }
+
+    if (tid == 0) {
+        dst[row] = __hadd(tmp.x, tmp.y);
+    }
+}
+
+
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const half    * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = __low2float(x[i].dm);
+        const float dmin = __high2float(x[i].dm);
+
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += __half2float(y[l+ 0]) * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + __half2float(y[l+32]) * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + __half2float(y[l+64]) * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + __half2float(y[l+96]) * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + __half2float(y[l+16]) * d[1] * ((q[l+16] >> 0) & 3)
+                  + __half2float(y[l+48]) * d[3] * ((q[l+16] >> 2) & 3)
+                  + __half2float(y[l+80]) * d[5] * ((q[l+16] >> 4) & 3)
+                  +__half2float(y[l+112]) * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += __half2float(y[l+ 0]) * m[0] + __half2float(y[l+32]) * m[2] + __half2float(y[l+64]) * m[4] + __half2float(y[ l+96]) * m[6]
+                  + __half2float(y[l+16]) * m[1] + __half2float(y[l+48]) * m[3] + __half2float(y[l+80]) * m[5] + __half2float(y[l+112]) * m[7];
+
+        }
+        tmp += dall * sum1 - dmin * sum2;
+
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols, int nrows) {
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const half    * y  = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = __half2float(x[i].d);
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += __half2float(y[l+ 0]) * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + __half2float(y[l+32]) * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + __half2float(y[l+64]) * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + __half2float(y[l+96]) * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += __half2float(y[l+16]) * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + __half2float(y[l+48]) * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + __half2float(y[l+80]) * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + __half2float(y[l+112]) * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
+
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols, int nrows) {
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const half   * y1 = yy + i*QK_K + y_offset;
+        const half   * y2 = y1 + 128;
+
+        const float dall = __low2float(x[i].dm);
+        const float dmin = __high2float(x[i].dm);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 4; ++l) {
+            s.x += __half2float(y1[l]) * q4[l+0]; s.y += __half2float(y1[l+32]) * q4[l+ 4];
+            s.z += __half2float(y2[l]) * q4[l+8]; s.w += __half2float(y2[l+32]) * q4[l+12];
+            smin += __half2float(y1[l]) * sc[2] + __half2float(y1[l+32]) * sc[3] + __half2float(y2[l]) * sc[6] + __half2float(y2[l+32]) * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += __half2float(y1[l]) * q4[l+0]; s.y += __half2float(y1[l+32]) * q4[l+2];
+            s.z += __half2float(y2[l]) * q4[l+4]; s.w += __half2float(y2[l+32]) * q4[l+6];
+            smin += __half2float(y1[l]) * sc[2] + __half2float(y1[l+32]) * sc[3] + __half2float(y2[l]) * sc[6] + __half2float(y2[l+32]) * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols) {
+
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * qh  = x[i].qh + l0;
+        const half    * y1  = yy + i*QK_K + y_offset;
+        const half    * y2  = y1 + 128;
+
+        const float dall = __low2float(x[i].dm);
+        const float dmin = __high2float(x[i].dm);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
+        for (int l = 0; l < n; ++l) {
+            sum.x += __half2float(y1[l+ 0]) * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + __half2float(y1[l+16]) * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += __half2float(y1[l+32]) * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + __half2float(y1[l+48]) * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += __half2float(y2[l+ 0]) * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + __half2float(y2[l+16]) * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += __half2float(y2[l+32]) * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + __half2float(y2[l+48]) * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (__half2float(y1[l]) + __half2float(y1[l+16])) * sc[2] + (__half2float(y1[l+32]) + __half2float(y1[l+48])) * sc[3]
+                  + (__half2float(y2[l]) + __half2float(y2[l+16])) * sc[6] + (__half2float(y2[l+32]) + __half2float(y2[l+48])) * sc[7];
+        }
+        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const half    * y  = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = __half2float(x[i].d);
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = __half2float(y[ 0]) * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + __half2float(y[16]) * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + __half2float(y[32]) * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + __half2float(y[48]) * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + __half2float(y[64]) * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + __half2float(y[80]) * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + __half2float(y[96]) * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +__half2float(y[112]) * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += __half2float(y[l+ 0]) * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + __half2float(y[l+32]) * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + __half2float(y[l+64]) * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + __half2float(y[l+96]) * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_iq2_xxs(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols, int nrows) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_iq2_xxs * x = (const block_iq2_xxs *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const int tid = threadIdx.x/4;
+    const int ix  = threadIdx.x%4;
+
+    const int q_offset = tid * 4;
+    const int y_offset = tid * 32;
+
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
+
+        const half    * y = yy + i * QK_K + y_offset;
+        const uint16_t * q = x[i].qs + q_offset;
+
+        const uint8_t  * aux8 = (const uint8_t *)q;
+        uint32_t aux32 = q[2] | (q[3] << 16);
+        float sumi = 0;
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+            const uint8_t  signs = ksigns_iq2xs[aux32 & 127];
+            for (int j = 0; j < 8; ++j) {
+                sumi += __half2float(y[j]) * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+            }
+            y += 8;
+            aux32 >>= 7;
+        }
+        tmp += sumi * __half2float(x[i].d) * (0.5f + aux32) * 0.25f;;
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static __global__ void dequantize_mul_mat_vec_iq2_xs(const void * __restrict__ vx, const dfloat * __restrict__ yy, dfloat * __restrict__ dst, const int ncols, int nrows) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_iq2_xs * x = (const block_iq2_xs *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    const int tid = threadIdx.x/4;
+    const int ix  = threadIdx.x%4;
+
+    const int q_offset = tid * 4;
+    const int s_offset = tid;
+    const int y_offset = tid * 32;
+
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
+        const half    * y = yy + i * QK_K + y_offset;
+        const uint16_t * q = x[i].qs + q_offset;
+        const uint8_t ls1 = x[i].scales[s_offset] & 0xf;
+        const uint8_t ls2 = x[i].scales[s_offset] >>  4;
+
+        float sumi1 = 0;
+        for (int l = 0; l < 2; ++l) {
+            const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q[l] & 511));
+            const uint8_t  signs = ksigns_iq2xs[q[l] >> 9];
+            for (int j = 0; j < 8; ++j) {
+                sumi1 += __half2float(y[j]) * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+            }
+            y += 8;
+        }
+        float sumi2 = 0;
+        for (int l = 2; l < 4; ++l) {
+            const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q[l] & 511));
+            const uint8_t  signs = ksigns_iq2xs[q[l] >> 9];
+            for (int j = 0; j < 8; ++j) {
+                sumi2 += __half2float(y[j]) * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+            }
+            y += 8;
+        }
+        const float d = __half2float(x[i].d) * 0.25f;
+        tmp += d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);;
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int block_num_y = (nrows + ny - 1) / ny;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(32, ny, 1);
+    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_iq2_xxs_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_iq2_xxs<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_iq2_xs_cuda(const void * vx, const dfloat * y, dfloat * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(32, 1, 1);
+    dequantize_mul_mat_vec_iq2_xs<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+// Q8 gemv
+static __global__ void quantize_q8_1(const half * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
+    const int ix = blockDim.x*blockIdx.x + threadIdx.x;
+    if (ix >= kx_padded) {
+        return;
+    }
+    const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i_padded = iy*kx_padded + ix;
+
+    block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int ib = i_padded / QK8_1; // block index
+    const int iqs = i_padded % QK8_1; // quant index
+
+    const float xi = ix < kx ? __half2float(x[iy*kx + ix]) : 0.0f;
+    float amax = fabsf(xi);
+    float sum = xi;
+
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
+        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
+    }
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs > 0) {
+        return;
+    }
+
+    y[ib].ds.x = __float2half(d);
+    y[ib].ds.y = __float2half(sum);
+}
+
+static void quantize_row_q8_1_cuda(const half * x, void * vy, const int kx, const int ky, cudaStream_t stream) {
+    const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
+    const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ky, 1);
+    const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+}
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+    const int * v, const int * u, const float & d4, const half2 & ds8) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
+
+    const float2 ds8f = __half22float2(ds8);
+
+    // second part effectively subtracts 8 from each quant value
+    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
+
+    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+    const float d4d8 = tmp.x;
+    const float m4s8 = tmp.y;
+
+    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+    const float2 ds8f = __half22float2(ds8);
+
+    // second part effectively subtracts 16 from each quant value
+    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
+}
+
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+    const float d5d8 = tmp.x;
+    const float m5s8 = tmp.y;
+
+    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
+    const int * v, const int * u, const float & d8_0, const float & d8_1) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
+    return d8_0*d8_1 * sumi;
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
+
+    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+    const float d8d8 = tmp.x;
+    const float m8s8 = tmp.y;
+
+    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float * __restrict__ d8) {
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = scales[2*i];
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+
+        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+        sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return dm2f.x*sumf_d - dm2f.y*sumf_m;
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float & d8) {
+    int sumi_d = 0;
+    int sumi_m = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+        int sumi_d_sc = 0;
+
+        const int sc = scales[i0 / (QI8_1/2)];
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+            sumi_m    = __dp4a(m,    u[i], sumi_m); // multiply sum of q8_1 values with m
+        }
+
+        sumi_d += sumi_d_sc * (sc & 0xF);
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi = __vsubss4(vil, vih);
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d3 * sumf;
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d3, const float & d8) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+        int sumi_sc = 0;
+
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+        }
+
+        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+    }
+
+    return d3*d8 * sumi;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K; ++i) {
+        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ  8
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+        const int v0i = vl0i | vh0i;
+        const int v1i = vl1i | vh1i;
+
+        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);
+    }
+
+    const float2 dm5f = __half22float2(dm5);
+    return dm5f.x*sumf_d - dm5f.y*sumf_m;
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d, const float * __restrict__ d8) {
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = scales[4*i];
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+    const float & d6, const float * __restrict__ d8) {
+    float sumf_d = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+        for (int i = i0; i < i0 + 2; ++i) {
+            sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+            sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+            sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+            sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+        }
+
+        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+    }
+
+    return d6 * sumf_d;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+    int v[VDR_Q4_0_Q8_1_MMVQ];
+    int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+        v[i]     = get_int_from_uint8(bq4_0->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, __half2float(bq4_0->d), bq8_1->ds);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
+    *x_ql = tile_x_qs;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI4_0;
+    const int kqsx = k % QI4_0;
+
+    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d);
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const float * x_dmf = (const float *) x_dm;
+
+    int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+    int v[VDR_Q4_1_Q8_1_MMVQ];
+    int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+        v[i]    = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) +     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
+    *x_ql = tile_x_qs;
+    *x_dm = tile_x_dm;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI4_1;
+    const int kqsx = k % QI4_1;
+
+    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+    int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+    int vl[VDR_Q5_0_Q8_1_MMVQ];
+    int vh[VDR_Q5_0_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+        vl[i]    = get_int_from_uint8(bq5_0->qs, iqs + i);
+        vh[i]    = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+    }
+
+    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, __half2float(bq5_0->d), bq8_1->ds);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int  tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
+
+    *x_ql = tile_x_ql;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI5_0;
+    const int kqsx = k % QI5_0;
+
+    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ql = get_int_from_uint8(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d);
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+    }
+
+    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+    int vl[VDR_Q5_1_Q8_1_MMVQ];
+    int vh[VDR_Q5_1_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+        vl[i]   = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+        vh[i]   = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+    }
+
+    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI5_1;
+    const int kqsx = k % QI5_1;
+
+    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
+
+    int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+    }
+
+    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+    int v[VDR_Q8_0_Q8_1_MMVQ];
+    int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+        v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+        u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+    }
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, __half2float(bq8_0->d), __low2float(bq8_1->ds));
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
+
+    *x_ql = tile_x_qs;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI8_0;
+    const int kqsx = k % QI8_0;
+    float * x_dmf = (float *) x_dm;
+
+    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d);
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
+         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+    const int bq8_offset = QR2_K * (iqs / QI8_1);
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const uint8_t * scales = bq2_K->scales + scale_offset;
+
+    const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+    int    u[QR2_K];
+    float d8[QR2_K];
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++ i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+    }
+
+    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4)     + mmq_y/4];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI2_K;
+    const int kqsx = k % QI2_K;
+
+    const block_q2_K * bx0 = (const block_q2_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const int kbx = k / QI2_K;
+    const int ky  = (k % QI2_K) * QR2_K;
+    const float * y_df = (const float *) y_ds;
+
+    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+#pragma unroll
+    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+    }
+
+    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
+    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const float d = __half2float(bq3_K->d);
+
+    const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+    const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+    int    u[QR3_K];
+    float d8[QR3_K];
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+    }
+
+    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
+    __shared__ int   tile_x_qh[mmq_y * (WARP_SIZE/2)     + mmq_y/2];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4)     + mmq_y/4];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_qh = tile_x_qh;
+    *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI3_K;
+    const int kqsx = k % QI3_K;
+
+    const block_q3_K * bx0 = (const block_q3_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+    const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d);
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
+        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
+
+        const int ksc = k % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    const int kbx  = k / QI3_K;
+    const int ky  = (k % QI3_K) * QR3_K;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+        const int shift = 2 * ((ky % 32) / 8);
+        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+        const int vlh = (vh << 2) & 0x04040404;
+
+        v[l] = __vsubss4(vll, vlh);
+    }
+
+    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
+    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+    int    v[2];
+    int    u[2*QR4_K];
+    float d8[QR4_K];
+
+    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    v[0] = q4[0];
+    v[1] = q4[4];
+
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = __low2float(bq8i->ds);
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI4_K; // == k if QK_K == 256
+
+    const block_q4_K * bx0 = (const block_q4_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
+        if (need_check) {
+            i = min(i, i_max);
+        }
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;
+
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
+
+    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+    int   vl[2];
+    int   vh[2];
+    int    u[2*QR5_K];
+    float d8[QR5_K];
+
+    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+    vl[0] = ql[0];
+    vl[1] = ql[4];
+
+    vh[0] = qh[0] >> bq8_offset;
+    vh[1] = qh[4] >> bq8_offset;
+
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = __low2float(bq8i->ds);
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI5_K; // == k if QK_K == 256
+
+    const block_q5_K * bx0 = (const block_q5_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR5_K*kqsx;
+
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
+        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
+
+    const int index_x = i * (QR5_K*WARP_SIZE + 1) +  QR5_K*k;
+    const int index_y = j * WARP_SIZE             + (QR5_K*k) % WARP_SIZE;
+    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+    const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+    const int8_t * scales = bq6_K->scales + scale_offset;
+
+    int    u[QR6_K];
+    float d8[QR6_K];
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
+    }
+
+    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, __half2float(bq6_K->d), d8);
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
+    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
+    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8)     + mmq_y/8];
+
+    *x_ql = tile_x_ql;
+    *x_dm = tile_x_dm;
+    *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
+    const int kqsx = k % QI6_K; // == k if QK_K == 256
+
+    const block_q6_K * bx0 = (const block_q6_K *) vx;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR6_K*kqsx;
+
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
+
+        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d);
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
+    }
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
+
+    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
+
+    const int index_x = i * (QR6_K*WARP_SIZE + 1) +  QR6_K*k;
+    const int index_y = j * WARP_SIZE             + (QR6_K*k) % WARP_SIZE;
+    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+        const uint8_t  signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = __half2float(bq2->d) * (0.5f + aux32) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
+    return d * sumi;
+}
+
+static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+    const uint8_t ls2 = bq2->scales[ib32] >>  4;
+    int sumi1 = 0;
+    for (int l = 0; l < 2; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
+        for (int j = 0; j < 8; ++j) {
+            sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+    }
+    int sumi2 = 0;
+    for (int l = 2; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
+        for (int j = 0; j < 8; ++j) {
+            sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+    }
+    const float d = __half2float(bq2->d) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
+    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+}
+
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = __float2half(tmp);
+    }
+}
+
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
+              allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+static __device__ __forceinline__ void mul_mat_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ncols_dst = ncols_y;
+
+    const int row_dst_0 = blockIdx.x*mmq_y;
+    const int & row_x_0 = row_dst_0;
+
+    const int col_dst_0 = blockIdx.y*mmq_x;
+    const int & col_y_0 = col_dst_0;
+
+    int   * tile_x_ql = nullptr;
+    half2 * tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
+
+    __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
+    __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
+
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
+
+    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
+
+#pragma unroll
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir*WARP_SIZE + threadIdx.x;
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i = 0; i < mmq_x; i += nwarps) {
+                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
+                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = __low2float(*dsi_src);
+                }
+            }
+
+            __syncthreads();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
+#pragma unroll
+                for (int j = 0; j < mmq_x; j += nwarps) {
+#pragma unroll
+                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
+                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+                            threadIdx.x + i, threadIdx.y + j, k);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+#pragma unroll
+    for (int j = 0; j < mmq_x; j += nwarps) {
+        const int col_dst = col_dst_0 + j + threadIdx.y;
+        if (col_dst >= ncols_dst) {
+            return;
+        }
+
+#pragma unroll
+        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+            const int row_dst = row_dst_0 + threadIdx.x + i;
+            if (row_dst >= nrows_dst) {
+                continue;
+            }
+            dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
+        }
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q4_0  64
+#define  MMQ_Y_Q4_0  128
+#define NWARPS_Q4_0  8
+#else
+#define  MMQ_X_Q4_0 4
+#define  MMQ_Y_Q4_0 32
+#define NWARPS_Q4_0 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2)
+#endif
+mul_mat_q4_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q4_0;
+    const int mmq_y  =  MMQ_Y_Q4_0;
+    const int nwarps = NWARPS_Q4_0;
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q4_0_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    int mmq_x  =  MMQ_X_Q4_0;
+    int mmq_y  =  MMQ_Y_Q4_0;
+    int nwarps = NWARPS_Q4_0;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q4_1 64
+#define  MMQ_Y_Q4_1 128
+#define NWARPS_Q4_1 8
+#else
+#define  MMQ_X_Q4_1 4
+#define  MMQ_Y_Q4_1 32
+#define NWARPS_Q4_1 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2)
+#endif
+mul_mat_q4_1(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q4_1;
+    const int mmq_y  =  MMQ_Y_Q4_1;
+    const int nwarps = NWARPS_Q4_1;
+
+    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+        load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q4_1_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    int mmq_x  =  MMQ_X_Q4_1;
+    int mmq_y  =  MMQ_Y_Q4_1;
+    int nwarps = NWARPS_Q4_1;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q5_0 64
+#define  MMQ_Y_Q5_0 128
+#define NWARPS_Q5_0 8
+#else
+#define  MMQ_X_Q5_0 4
+#define  MMQ_Y_Q5_0 32
+#define NWARPS_Q5_0 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2)
+#endif
+mul_mat_q5_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q5_0;
+    const int mmq_y  =  MMQ_Y_Q5_0;
+    const int nwarps = NWARPS_Q5_0;
+
+    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+        load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q5_0_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int mmq_x  =  MMQ_X_Q5_0;
+    const int mmq_y  =  MMQ_Y_Q5_0;
+    const int nwarps = NWARPS_Q5_0;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q5_1 64
+#define  MMQ_Y_Q5_1 128
+#define NWARPS_Q5_1 8
+#else
+#define  MMQ_X_Q5_1 4
+#define  MMQ_Y_Q5_1 32
+#define NWARPS_Q5_1 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2)
+#endif
+mul_mat_q5_1(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q5_1;
+    const int mmq_y  =  MMQ_Y_Q5_1;
+    const int nwarps = NWARPS_Q5_1;
+
+    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+        load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q5_1_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+    const int mmq_x  =  MMQ_X_Q5_1;
+    const int mmq_y  =  MMQ_Y_Q5_1;
+    const int nwarps = NWARPS_Q5_1;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q8_0 64
+#define  MMQ_Y_Q8_0 128
+#define NWARPS_Q8_0 8
+#else
+#define  MMQ_X_Q8_0 4
+#define  MMQ_Y_Q8_0 32
+#define NWARPS_Q8_0 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2)
+#endif
+mul_mat_q8_0(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q8_0;
+    const int mmq_y  =  MMQ_Y_Q8_0;
+    const int nwarps = NWARPS_Q8_0;
+
+    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+        load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q8_0_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+    const int mmq_x  =  MMQ_X_Q8_0;
+    const int mmq_y  =  MMQ_Y_Q8_0;
+    const int nwarps = NWARPS_Q8_0;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q2_K 64
+#define  MMQ_Y_Q2_K 128
+#define NWARPS_Q2_K 8
+#else
+#define  MMQ_X_Q2_K 4
+#define  MMQ_Y_Q2_K 32
+#define NWARPS_Q2_K 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2)
+#endif
+mul_mat_q2_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q2_K;
+    const int mmq_y  =  MMQ_Y_Q2_K;
+    const int nwarps = NWARPS_Q2_K;
+
+    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+        load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q2_K_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+    const int mmq_x  =  MMQ_X_Q2_K;
+    const int mmq_y  =  MMQ_Y_Q2_K;
+    const int nwarps = NWARPS_Q2_K;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q3_K 64
+#define  MMQ_Y_Q3_K 128
+#define NWARPS_Q3_K 8
+#else
+#define  MMQ_X_Q3_K 4
+#define  MMQ_Y_Q3_K 32
+#define NWARPS_Q3_K 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2)
+#endif
+mul_mat_q3_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+    const int mmq_x  =  MMQ_X_Q3_K;
+    const int mmq_y  =  MMQ_Y_Q3_K;
+    const int nwarps = NWARPS_Q3_K;
+
+    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+        load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q3_K_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int mmq_x  =  MMQ_X_Q3_K;
+    const int mmq_y  =  MMQ_Y_Q3_K;
+    const int nwarps = NWARPS_Q3_K;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q4_K 64
+#define  MMQ_Y_Q4_K 128
+#define NWARPS_Q4_K 8
+#else
+#define  MMQ_X_Q4_K 4
+#define  MMQ_Y_Q4_K 32
+#define NWARPS_Q4_K 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2)
+#endif
+mul_mat_q4_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q4_K;
+    const int mmq_y  =  MMQ_Y_Q4_K;
+    const int nwarps = NWARPS_Q4_K;
+
+    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+        load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q4_K_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+    const int mmq_x  =  MMQ_X_Q4_K;
+    const int mmq_y  =  MMQ_Y_Q4_K;
+    const int nwarps = NWARPS_Q4_K;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q5_K 64
+#define  MMQ_Y_Q5_K 128
+#define NWARPS_Q5_K 8
+#else
+#define  MMQ_X_Q5_K 4
+#define  MMQ_Y_Q5_K 32
+#define NWARPS_Q5_K 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2)
+#endif
+mul_mat_q5_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q5_K;
+    const int mmq_y  =  MMQ_Y_Q5_K;
+    const int nwarps = NWARPS_Q5_K;
+
+    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+        load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q5_K_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+    const int mmq_x  =  MMQ_X_Q5_K;
+    const int mmq_y  =  MMQ_Y_Q5_K;
+    const int nwarps = NWARPS_Q5_K;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+#if defined(USE_ROCM)
+#define  MMQ_X_Q6_K 64
+#define  MMQ_Y_Q6_K 128
+#define NWARPS_Q6_K 8
+#else
+#define  MMQ_X_Q6_K 4
+#define  MMQ_Y_Q6_K 32
+#define NWARPS_Q6_K 4
+#endif
+
+template <bool need_check> static __global__ void
+#if defined(USE_ROCM)
+__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2)
+#endif
+mul_mat_q6_K(
+    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q6_K;
+    const int mmq_y  =  MMQ_Y_Q6_K;
+    const int nwarps = NWARPS_Q6_K;
+
+    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+        load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q6_K_q8_1_cuda(
+    const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
+    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+    const int mmq_x  =  MMQ_X_Q6_K;
+    const int mmq_y  =  MMQ_Y_Q6_K;
+    const int nwarps = NWARPS_Q6_K;
+
+    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    if (nrows_x % mmq_y == 0) {
+        const bool need_check = false;
+        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    } else {
+        const bool need_check = true;
+        mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
+            (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+    }
+}
+
+torch::Tensor ggml_dequantize(
+    torch::Tensor W,   // quant weight
+    int8_t type,
+    int64_t m,
+    int64_t n
+){
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
+    auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+    at::Tensor DW = torch::empty({m, n}, options);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
+    to_fp16_cuda(
+        (void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream
+    );
+    return DW;
+}
+
+torch::Tensor ggml_mul_mat_vec(
+    torch::Tensor W,  // quant weight
+    torch::Tensor X,  // input
+    int8_t type,
+    int64_t row
+){
+    size_t col = X.sizes()[1];
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
+    auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+    at::Tensor Y = torch::empty({1, row}, options);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    switch (type) {
+        case 2:
+            dequantize_mul_mat_vec_q4_0_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 3:
+            dequantize_mul_mat_vec_q4_1_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 6:
+            dequantize_mul_mat_vec_q5_0_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 7:
+            dequantize_mul_mat_vec_q5_1_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 8:
+            dequantize_mul_mat_vec_q8_0_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 10:
+            dequantize_mul_mat_vec_q2_K_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 11:
+            dequantize_mul_mat_vec_q3_K_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 12:
+            dequantize_mul_mat_vec_q4_K_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 13:
+            dequantize_mul_mat_vec_q5_K_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 14:
+            dequantize_mul_mat_vec_q6_K_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 16:
+            dequantize_mul_mat_vec_iq2_xxs_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 17:
+            dequantize_mul_mat_vec_iq2_xs_cuda((void*)W.data_ptr(), (half*)X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+    }
+    return Y;
+}
+
+torch::Tensor ggml_mul_mat_vec_a8(
+    torch::Tensor W,  // quant weight
+    torch::Tensor X,  // input
+    int8_t type,
+    int64_t row
+){
+    int col = X.sizes()[1];
+    const int padded = (col + 512 - 1) / 512 * 512;
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
+    auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+    at::Tensor Y = torch::empty({1, row}, options);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
+    at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
+    quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1, stream);
+    switch (type) {
+        case 2:
+            mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 3:
+            mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 6:
+            mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 7:
+            mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 8:
+            mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 10:
+            mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 11:
+            mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 12:
+            mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 13:
+            mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 14:
+            mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 16:
+            mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+        case 17:
+            mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream);
+            break;
+    }
+    return Y;
+}
+
+torch::Tensor ggml_mul_mat_a8(
+    torch::Tensor W,  // quant weight
+    torch::Tensor X,  // input
+    int8_t type,
+    int64_t row
+) {
+    int col = X.sizes()[1];
+    int padded = (col + 512 - 1) / 512 * 512;
+    int batch = X.sizes()[0];
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
+    auto options = torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+    at::Tensor Y = torch::empty({batch, row}, options);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
+    at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
+    quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, batch, stream);
+
+    switch (type) {
+        case 2:
+            ggml_mul_mat_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 3:
+            ggml_mul_mat_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 6:
+            ggml_mul_mat_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 7:
+            ggml_mul_mat_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 8:
+            ggml_mul_mat_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 10:
+            ggml_mul_mat_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 11:
+            ggml_mul_mat_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 12:
+            ggml_mul_mat_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 13:
+            ggml_mul_mat_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+        case 14:
+            ggml_mul_mat_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, batch, padded, row, stream);
+            break;
+    }
+    return Y;
+}

+ 843 - 0
kernels/quantization/marlin/marlin_cuda_kernel.cu

@@ -0,0 +1,843 @@
+/*
+ * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#ifndef MARLIN_CUDA_KERNEL_CUH
+#define MARLIN_CUDA_KERNEL_CUH
+
+#include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+namespace aphrodite {
+namespace marlin {
+
+constexpr int ceildiv(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
+// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) {
+    return elems[i];
+  }
+};
+
+using I4 = Vec<int, 4>;
+
+// Matrix fragments for tensor core instructions; their precise layout is documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>; // quantization scales
+
+// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that
+// are not multiples of 16.
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+    "{\n"
+    "   .reg .pred p;\n"
+    "   setp.ne.b32 p, %0, 0;\n"
+    "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+    "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)
+  );
+}
+
+// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for
+// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
+// for inputs A and outputs C.
+__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+    "{\n"
+    "   .reg .b64 p;\n"
+    "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
+    "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
+    "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
+  );
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template <int n>
+__device__ inline void cp_async_wait() {
+  asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
+}
+
+// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  asm volatile(
+    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+    "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+    : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+    :  "r"(a[0]),  "r"(a[1]),  "r"(a[2]),  "r"(a[3]),  "r"(b[0]),  "r"(b[1]),
+       "f"(c[0]),  "f"(c[1]),  "f"(c[2]),  "f"(c[3])
+  );
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+    "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+    : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
+  );
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
+// automatically recognize it in all cases.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile(
+    "lop3.b32 %0, %1, %2, %3, %4;\n"
+    : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
+  );
+  return res;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
+// We mostly follow the strategy in the link below, with some small changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__device__ inline FragB dequant(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
+  const int SUB = 0x64086408;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd480d480;
+  FragB frag_b;
+  frag_b[0] = __hsub2(
+    *reinterpret_cast<half2*>(&lo),
+    *reinterpret_cast<const half2*>(&SUB)
+  );
+  frag_b[1] = __hfma2(
+    *reinterpret_cast<half2*>(&hi),
+    *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)
+  );
+  return frag_b;
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible globally.
+      asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
+    while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
+    asm volatile ("fence.acq_rel.gpu;\n");
+    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
+  }
+}
+
+template <
+  const int threads, // number of threads in a threadblock
+  const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
+  const int thread_n_blocks, // same for n dimension (output)
+  const int thread_k_blocks, // same for k dimension (reduction)
+  const int stages, // number of stages for the async global->shared fetch pipeline
+  const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
+>
+__global__ void Marlin(
+  const int4* __restrict__ A, // fp16 input matrix of shape mxk
+  const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
+        int4* __restrict__ C, // fp16 output buffer of shape mxn
+  const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+  int  prob_m, // batch dimension m
+  int  prob_n, // output dimension n
+  int  prob_k, // reduction dimension k
+  int* locks // extra global storage for barrier synchronization
+) {
+  // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
+  // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
+  //   0 1 3
+  //   0 2 3
+  //   1 2 4
+  // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
+  // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
+  // possible.
+
+  int k_tiles = prob_k / 16 / thread_k_blocks;
+  int n_tiles = prob_n / 16 / thread_n_blocks;
+  int iters = ceildiv(k_tiles * n_tiles, gridDim.x);
+  // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
+  // where a stripe starts in the middle of group.
+  if (group_blocks != -1)
+    iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));
+
+  int slice_row = (iters * blockIdx.x) % k_tiles;
+  int slice_col = (iters * blockIdx.x) / k_tiles;
+  int slice_iters; // number of threadblock tiles in the current slice
+  int slice_count = 0; // total number of active threadblocks in the current slice
+  int slice_idx; // index of threadblock in current slice; numbered bottom to top
+
+  // Compute all information about the current slice which is required for synchronization.
+  auto init_slice = [&] () {
+    slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row);
+    if (slice_iters < 0 || slice_col >= n_tiles)
+      slice_iters = 0;
+    if (slice_iters == 0)
+      return;
+    if (slice_row + slice_iters > k_tiles)
+      slice_iters = k_tiles - slice_row;
+    slice_count = 1;
+    slice_idx = 0;
+    int col_first = iters * ceildiv(k_tiles * slice_col, iters);
+    if (col_first <= k_tiles * (slice_col + 1)) {
+      int col_off = col_first - k_tiles * slice_col;
+      slice_count = ceildiv(k_tiles - col_off, iters);
+      if (col_off > 0)
+        slice_count++;
+      int delta_first = iters * blockIdx.x - col_first;
+      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
+        slice_idx = slice_count - 1;
+      else {
+        slice_idx = slice_count - 1 - delta_first / iters;
+        if (col_off > 0)
+          slice_idx--;
+      }
+    }
+  };
+  init_slice();
+
+  int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
+  // We typically use `constexpr` to indicate that this value is a compile-time constant
+  constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
+  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
+  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
+  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
+  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
+  constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
+  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
+  constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile
+
+  int b_gl_stride = 16 * prob_n / 32;
+  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
+  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
+  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
+  constexpr int b_sh_wr_delta = threads;
+  constexpr int b_sh_rd_delta = threads;
+  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
+  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
+
+  int s_gl_stride = prob_n / 8;
+  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
+  constexpr int s_sh_stage = s_sh_stride;
+  int s_gl_rd_delta = s_gl_stride;
+
+  // Global A read index of current thread.
+  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
+  a_gl_rd += a_gl_rd_delta_o * slice_row;
+  // Shared write index of current thread.
+  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
+  // Shared read index.
+  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
+  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
+
+  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+  b_gl_rd += b_sh_stride * slice_col;
+  b_gl_rd += b_gl_rd_delta_o * slice_row;
+  int b_sh_wr = threadIdx.x;
+  int b_sh_rd = threadIdx.x;
+
+  int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
+  int s_sh_wr = threadIdx.x;
+  int s_sh_rd;
+  // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
+  // layout in the former and in row-major in the latter case.
+  if (group_blocks != -1)
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
+  else
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
+
+  // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
+  // required for a certain tilesize or when the batchsize is not a multiple of 16.
+  bool a_sh_wr_pred[a_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < a_sh_wr_iters; i++)
+    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
+  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
+
+  // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
+  // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
+  // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
+  // on NSight-Compute) that each warp must also write a consecutive memory segment?
+  auto transform_a = [&] (int i) {
+    int row = i / a_gl_rd_delta_o;
+    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
+  };
+  // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
+  // accesses are static, we simply precompute both transformed reads and writes.
+  int a_sh_wr_trans[a_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < a_sh_wr_iters; i++)
+    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
+  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
+  #pragma unroll
+  for (int i = 0; i < b_sh_wr_iters; i++) {
+    #pragma unroll
+    for (int j = 0; j < thread_m_blocks; j++)
+      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
+  }
+
+  // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between
+  // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization.
+  const int4* B_ptr[b_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < b_sh_wr_iters; i++)
+    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
+
+  extern __shared__ int4 sh[];
+  // Shared memory storage for global fetch pipelines.
+  int4* sh_a = sh;
+  int4* sh_b = sh_a + (stages * a_sh_stage);
+  int4* sh_s = sh_b + (stages * b_sh_stage);
+  // Register storage for double buffer of shared memory reads.
+  FragA frag_a[2][thread_m_blocks];
+  I4 frag_b_quant[2];
+  FragC frag_c[thread_m_blocks][4][2];
+  FragS frag_s[2][4];
+
+  // Zero accumulators.
+  auto zero_accums = [&] () {
+    #pragma unroll
+    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
+      reinterpret_cast<float*>(frag_c)[i] = 0;
+  };
+
+  // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
+  auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
+    if (pred) {
+      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+      #pragma unroll
+      for (int i = 0; i < a_sh_wr_iters; i++) {
+        cp_async4_pred(
+          &sh_a_stage[a_sh_wr_trans[i]],
+          &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
+          a_sh_wr_pred[i]
+        );
+      }
+      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+      #pragma unroll
+      for (int i = 0; i < b_sh_wr_iters; i++) {
+        cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+        B_ptr[i] += b_gl_rd_delta_o;
+      }
+      // Only fetch scales if this tile starts a new group
+      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+        if (s_sh_wr_pred)
+          cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+        s_gl_rd += s_gl_rd_delta;
+      }
+    }
+    // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
+    cp_async_fence();
+  };
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&] () {
+    // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
+    // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
+    cp_async_wait<stages - 2>();
+    __syncthreads();
+  };
+
+  // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
+  auto fetch_to_registers = [&] (int k, int pipe) {
+    // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
+    // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the
+    // compiler and correspondingly a noticable drop in performance.
+    if (group_blocks != -1) {
+      int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
+      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
+    }
+    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+    #pragma unroll
+    for (int i = 0; i < thread_m_blocks; i++)
+      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+  };
+
+  // Execute the actual tensor core matmul of a sub-tile.
+  auto matmul = [&] (int k) {
+    // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
+    #pragma unroll
+    for (int j = 0; j < 4; j++) {
+      int b_quant = frag_b_quant[k % 2][j];
+      int b_quant_shift = b_quant >> 8;
+      FragB frag_b0 = dequant(b_quant);
+      // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
+      if (group_blocks != -1)
+        scale(frag_b0, frag_s[k % 2][j], 0);
+      FragB frag_b1 = dequant(b_quant_shift);
+      if (group_blocks != -1)
+        scale(frag_b1, frag_s[k % 2][j], 1);
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+      }
+    }
+  };
+
+  // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
+  // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
+  // location; which we have to reduce over in the end. We do in shared memory.
+  auto thread_block_reduce = [&] () {
+    constexpr int red_off = threads / b_sh_stride / 2;
+    if (red_off >= 1) {
+      int red_idx = threadIdx.x / b_sh_stride;
+      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+
+      // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
+      // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
+
+      #pragma unroll
+      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+        #pragma unroll
+        for (int i = red_off; i > 0; i /= 2) {
+          if (i <= red_idx && red_idx < 2 * i) {
+            #pragma unroll
+            for (int j = 0; j < 4 * 2; j++) {
+              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+              if (i < red_off) {
+                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+                #pragma unroll
+                for (int k = 0; k < 4; k++)
+                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
+              }
+              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+            }
+          }
+          __syncthreads();
+        }
+        if (red_idx == 0) {
+          #pragma unroll
+          for (int i = 0; i < 4 * 2; i++) {
+            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+            #pragma unroll
+            for (int j = 0; j < 4; j++)
+              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
+          }
+        }
+        __syncthreads();
+      }
+    }
+  };
+
+  // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
+  // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather
+  // small, we perform this reduction serially in L2 cache.
+  auto global_reduce = [&] (bool first = false, bool last = false) {
+    // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
+    // To do this, we write out results in FP16 (but still reduce with FP32 compute).
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    if (threadIdx.x < active_threads) {
+      int c_gl_stride = prob_n / 8;
+      int c_gl_wr_delta_o = 8 * c_gl_stride;
+      int c_gl_wr_delta_i = 4 * (active_threads / 32);
+      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
+      c_gl_wr += (2 * thread_n_blocks) * slice_col;
+      constexpr int c_sh_wr_delta = active_threads;
+      int c_sh_wr = threadIdx.x;
+
+      int row = (threadIdx.x % 32) / 4;
+
+      if (!first) {
+        // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
+        // hence we also use async-copies even though these fetches are not actually asynchronous.
+        #pragma unroll
+        for (int i = 0; i < thread_m_blocks * 4; i++) {
+          cp_async4_pred(
+            &sh[c_sh_wr + c_sh_wr_delta * i],
+            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
+            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
+          );
+        }
+        cp_async_fence();
+        cp_async_wait<0>();
+      }
+
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks * 4; i++) {
+        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
+          if (!first) {
+            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+            #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
+                reinterpret_cast<__half*>(&c_red)[j]
+              );
+            }
+          }
+          if (!last) {
+            int4 c;
+            #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<__half*>(&c)[j] = __float2half(
+                reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
+              );
+            }
+            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
+          }
+        }
+      }
+    }
+  };
+
+  // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step,
+  // the reduction above is performed in fragment layout.
+  auto write_result = [&] () {
+    int c_gl_stride = prob_n / 8;
+    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
+
+    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
+    c_gl_wr += (2 * thread_n_blocks) * slice_col;
+    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+    c_sh_wr += 32 * (threadIdx.x / 32);
+    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
+
+    int c_gl_wr_end = c_gl_stride * prob_m;
+
+    // We first reorder in shared memory to guarantee the most efficient final global write patterns
+    auto write = [&] (int idx, float c0, float c1, FragS& s) {
+      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+      if (group_blocks == -1) // for per-column quantization we finally apply the scale here
+        res = __hmul2(res, s[0]);
+      ((half2*) sh)[idx] = res;
+    };
+    if (threadIdx.x / 32 < thread_n_blocks / 4) {
+      #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        #pragma unroll
+        for (int j = 0; j < 4; j++) {
+          int wr = c_sh_wr + 8 * j;
+          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+        }
+        c_sh_wr += 16 * (4 * c_sh_stride);
+      }
+    }
+    __syncthreads();
+
+    #pragma unroll
+    for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
+      if (c_gl_wr < c_gl_wr_end) {
+        C[c_gl_wr] = sh[c_sh_rd];
+        c_gl_wr += c_gl_wr_delta;
+        c_sh_rd += c_sh_rd_delta;
+      }
+    }
+  };
+
+  // Start global fetch and register load pipelines.
+  auto start_pipes = [&] () {
+    #pragma unroll
+    for (int i = 0; i < stages - 1; i++)
+      fetch_to_shared(i, i, i < slice_iters);
+    zero_accums();
+    wait_for_stage();
+    fetch_to_registers(0, 0);
+    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+  };
+  start_pipes();
+
+  // Main loop.
+  while (slice_iters) {
+    // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
+    // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
+    #pragma unroll
+    for (int pipe = 0; pipe < stages;) {
+      #pragma unroll
+      for (int k = 0; k < b_sh_wr_iters; k++) {
+        fetch_to_registers(k + 1, pipe % stages);
+        if (k == b_sh_wr_iters - 2) {
+          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
+          pipe++;
+          wait_for_stage();
+        }
+        matmul(k);
+      }
+      slice_iters--;
+      if (slice_iters == 0)
+        break;
+    }
+    a_gl_rd += a_gl_rd_delta_o * stages;
+
+    // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
+    // readable, other ways of writing the loop seemed to noticeably worse performance after compliation.
+    if (slice_iters == 0) {
+      cp_async_wait<0>();
+      bool last = slice_idx == slice_count - 1;
+      // For per-column scales, we only fetch them here in the final step before write-out
+      if (group_blocks == -1 && last) {
+        if (s_sh_wr_pred)
+          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+        cp_async_fence();
+      }
+      thread_block_reduce();
+      if (group_blocks == -1 && last) {
+        cp_async_wait<0>();
+        __syncthreads();
+        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+        }
+      }
+      if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
+        barrier_acquire(&locks[slice_col], slice_idx);
+        global_reduce(slice_idx == 0, last);
+        barrier_release(&locks[slice_col], last);
+      }
+      if (last) // only the last block in a slice actually writes the result
+        write_result();
+      slice_row = 0;
+      slice_col++;
+      init_slice();
+      if (slice_iters) {
+        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
+        #pragma unroll
+        for (int i = 0; i < b_sh_wr_iters; i++)
+          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+        start_pipes();
+      }
+    }
+  }
+}
+
+
+// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more
+// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
+const int THREADS = 256;
+const int STAGES = 4; // 4 pipeline stages fit into shared memory
+const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
+
+#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
+  else if ( \
+    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
+    group_blocks == GROUP_BLOCKS \
+  ) { \
+    cudaFuncSetAttribute( \
+      Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
+      cudaFuncAttributeMaxDynamicSharedMemorySize, \
+      SHARED_MEM \
+    ); \
+    Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS><<<blocks, THREADS, SHARED_MEM, stream>>>( \
+      A_ptr, B_ptr, C_ptr, s_ptr, \
+      prob_m, prob_n, prob_k, \
+      locks \
+    ); \
+  }
+
+const int ERR_PROB_SHAPE = 1;
+const int ERR_KERN_SHAPE = 2;
+
+int marlin_cuda(
+  const void* A,
+  const void* B,
+        void* C,
+        void* s,
+  int prob_m,
+  int prob_n,
+  int prob_k,
+  void* workspace,
+  int groupsize = -1,
+  int dev = 0,
+  cudaStream_t stream = 0,
+  int thread_k = -1,
+  int thread_n = -1,
+  int sms = -1
+) {
+  int tot_m = prob_m;
+  int tot_m_blocks = ceildiv(tot_m, 16);
+
+  if (sms == -1)
+    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
+  if (thread_k == -1 || thread_n == -1) {
+    if (prob_m <= 16) {
+      // For small batchizes, better partioning is slightly more important than better compute utilization
+      thread_k = 128;
+      thread_n = 128;
+    } else {
+      thread_k = 64;
+      thread_n = 256;
+    }
+  }
+
+  int thread_k_blocks = thread_k / 16;
+  int thread_n_blocks = thread_n / 16;
+  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
+  int blocks = sms;
+
+  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
+    return ERR_PROB_SHAPE;
+  if (prob_m == 0 || prob_n == 0 || prob_k == 0)
+    return 0;
+
+  const int4* A_ptr = (const int4*) A;
+  const int4* B_ptr = (const int4*) B;
+  int4* C_ptr = (int4*) C;
+  const int4* s_ptr = (const int4*) s;
+
+  int cols = prob_n / thread_n;
+  int* locks = (int*) workspace;
+
+  int ret = 0;
+  for (int i = 0; i < tot_m_blocks; i += 4) {
+    int thread_m_blocks = tot_m_blocks - i;
+    prob_m = tot_m - 16 * i;
+    if (thread_m_blocks > 4) {
+      thread_m_blocks = 4;
+      prob_m = 64;
+    }
+
+    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
+    // in our testing, however many more are, in principle, possible.
+    if (false) {}
+    CALL_IF(1,  8,  8, -1)
+    CALL_IF(1,  8,  8,  8)
+    CALL_IF(1, 16,  4, -1)
+    CALL_IF(1, 16,  4,  8)
+    CALL_IF(2, 16,  4, -1)
+    CALL_IF(2, 16,  4,  8)
+    CALL_IF(3, 16,  4, -1)
+    CALL_IF(3, 16,  4,  8)
+    CALL_IF(4, 16,  4, -1)
+    CALL_IF(4, 16,  4,  8)
+    else
+      ret = ERR_KERN_SHAPE;
+
+    A_ptr += 16 * thread_m_blocks * (prob_k / 8);
+    C_ptr += 16 * thread_m_blocks * (prob_n / 8);
+  }
+
+  return ret;
+}
+
+#endif
+
+} // namespace marlin
+} // namespace aphrodite
+
+const int ERR_PROB_SHAPE = 1;
+const int ERR_KERN_SHAPE = 2;
+
+// input:     `torch.half` input matrix of shape `(m, k)` in standard row-major layout
+// weights:   `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
+// output:    `torch.half` out matrix of shape `(m, n)` in standard row-major layout
+// scales:    `torch.half` scales of shape `(m / groupsize, n)`
+// workspace: `torch.int` tensor with at least `n / 128` entries that are all zero
+
+void marlin_gemm(
+  const torch::Tensor& input,
+  const torch::Tensor& weights,
+        torch::Tensor& output,
+  const torch::Tensor& scales,
+        torch::Tensor& workspace
+) {
+  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
+  int thread_k = -1;
+  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
+  int thread_n = -1;
+  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
+  int sms = -1;
+
+  int prob_m = input.size(0);
+  int prob_n = output.size(1);
+  int prob_k = input.size(1);
+  int groupsize = (scales.size(0) == 1) ? -1 : prob_k / scales.size(0);
+  if (groupsize != -1 && groupsize * scales.size(0) != prob_k)
+    AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups.");
+  int dev = input.get_device();
+  int err = aphrodite::marlin::marlin_cuda(
+    input.data_ptr(),
+    weights.data_ptr(),
+    output.data_ptr(),
+    scales.data_ptr(),
+    prob_m, prob_n, prob_k,
+    workspace.data_ptr(),
+    groupsize,
+    dev,
+    at::cuda::getCurrentCUDAStream(dev),
+    thread_k,
+    thread_n,
+    sms
+  );
+  if (err == ERR_PROB_SHAPE) {
+    AT_ERROR(
+      "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
+      " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
+    );
+  } else if (err == ERR_KERN_SHAPE) {
+    AT_ERROR(
+      "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
+    );
+  }
+}

+ 722 - 0
kernels/quantization/quip/origin_order.cu

@@ -0,0 +1,722 @@
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <mma.h>
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/DeviceGuard.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+
+template <typename U, typename V>
+constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
+  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
+  return (a / b);
+}
+
+template <typename U, typename V>
+constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
+  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
+  // Overflow safe variant of (a + b - 1) / b
+  const uint64_t blocks = a / b + (a % b != 0);
+  return blocks;
+}
+
+template <typename U, typename V>
+constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
+  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
+  return divDown(a, b) * b;
+}
+
+template <typename U, typename V>
+constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
+  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
+  return divUp(a, b) * b;
+}
+
+constexpr int32_t kWarpSize = 32;
+constexpr int32_t KTilesPerWarp = 8;
+constexpr int32_t kMTileSize = 16;
+constexpr int32_t kNTileSize = 8;
+constexpr int32_t kKTileSize = 16;
+
+struct __align__(16) f16x2x4_u32 {
+  uint32_t vals[4];
+};
+struct __align__(16) f16x2x2_u32 {
+  uint32_t vals[2];
+};
+
+struct ALayout_RM {
+template <int KTilesToLoad>
+static __device__ void load(
+    const half* A,
+    int32_t m,
+    int32_t k,
+    int32_t mTiles,
+    int32_t mTile,
+    int32_t kTiles,
+    int32_t kTileStart,
+    int32_t laneId,
+    f16x2x4_u32 out[KTilesToLoad]) {
+  const auto mLane = mTile * kMTileSize + (laneId / 4);
+  const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 4;
+
+  // access
+  // [mTile * kMTileSize + (laneId / 4)]
+  // [kTileStart * kKTileSize + (laneId % 4) * 2]
+  auto aPtr = A + mLane * k + kLane;
+
+  auto aPtrPlus8Rows = aPtr + 8 * k;
+
+  bool m0InBounds = mLane < m;
+  bool m1InBounds = (mLane + 8) < m;
+
+#pragma unroll
+  for (int i = 0; i < KTilesToLoad; ++i) {
+    out[i].vals[0] = m0InBounds
+          ? *reinterpret_cast<const uint32_t*>(aPtr  + i * kKTileSize)
+          : uint32_t(0);
+    out[i].vals[1] = m1InBounds
+          ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows  + i * kKTileSize)
+          : uint32_t(0);
+
+    out[i].vals[2] = m0InBounds
+          ? *reinterpret_cast<const uint32_t*>(aPtr  + i * kKTileSize + 2)
+          : uint32_t(0);
+    out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
+                                        aPtrPlus8Rows  + i * kKTileSize + 2)
+                                  : uint32_t(0);
+  }
+}
+
+static __device__ void store(
+    half* C,
+    int32_t m,
+    int32_t n,
+    int32_t mOutTiles,
+    int32_t mTile,
+    int32_t nOutTiles,
+    int32_t nTile,
+    int32_t laneId,
+    const float4& out) {
+
+  // sum.x / sum.y are written at
+  // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
+  // sum.z / sum.w are written at
+  // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
+  // i.e., same columns, different row.
+  const int outRow = mTile * kMTileSize + (laneId / 4);
+  const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
+
+  // Pointer where sum.x / sum.y is written
+  auto cPtr = C + outRow * n + outCol;
+
+  auto v01 = __float22half2_rn(float2{out.x, out.y});
+  auto v23 = __float22half2_rn(float2{out.z, out.w});
+
+  if (outRow < m) {
+    *reinterpret_cast<half2*>(cPtr) = v01;
+  }
+
+  // sum.z, sum.w at +8 rows from cPtr
+  if (outRow + 8 < m) {
+    *reinterpret_cast<half2*>(cPtr + 8 * n) = v23;
+  }
+}
+};
+
+struct BLayout_D4 {
+static constexpr bool use_codebook = true;
+
+template <int KTilesPerIteration>
+static __device__ void load(
+    const void* __restrict__ B,
+    const uint64_t* __restrict__ CB,
+    int32_t n,
+    int32_t k,
+    int32_t nTiles,
+    int32_t nTile,
+    int32_t kTiles,
+    int32_t kTileStart,
+    int32_t laneId,
+    f16x2x2_u32 b[KTilesPerIteration]) {
+  auto Bptr = reinterpret_cast<const uint8_t*>(B);
+  #pragma unroll
+  for (int i = 0; i < KTilesPerIteration; ++i) {
+       const int row = nTile * kNTileSize + laneId / 4;
+       const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
+       *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[row * k/4 + col]];
+  }
+}
+};
+
+struct BLayout_HI {
+static constexpr bool use_codebook = false;
+
+template <int KTilesPerIteration>
+static __device__ void load(
+    const void* __restrict__ B,
+    const uint64_t* __restrict__ CB,
+    int32_t n,
+    int32_t k,
+    int32_t nTiles,
+    int32_t nTile,
+    int32_t kTiles,
+    int32_t kTileStart,
+    int32_t laneId,
+    f16x2x2_u32 b[KTilesPerIteration]) {
+  auto Bptr = reinterpret_cast<const uint32_t*>(B);
+  #pragma unroll
+  for (int i = 0; i < KTilesPerIteration; ++i) {
+      const int row = nTile * kNTileSize + laneId / 4;
+      const int col = (kTileStart + i) * kKTileSize / 8 + (laneId % 4) / 2;
+      // simply use code - 7.5 instead of reading codebook
+      uint32_t code = Bptr[row * k/8 + col];
+
+      const uint32_t c0 = 0x64086408;
+      const half y16_ = __float2half_rn(1.0f / 16.0f);
+      const half2 y16 = __halves2half2(y16_, y16_);
+      const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
+      const half2 z16 = __halves2half2(z16_, z16_);
+
+      uint32_t qa = code >> ((laneId & 1) * 8);
+      uint32_t q0 = (((qa & 0x000f000f) << 4)| c0);
+      uint32_t q1 = ((qa & 0x00f000f0) | c0);
+      *(half2*)(b[i].vals) = __hfma2(*((half2*)(&q0)), y16, z16);
+      *(half2*)(b[i].vals+1) = __hfma2(*((half2*)(&q1)), y16, z16);
+  }
+}
+};
+
+struct BLayout_E8 {
+static constexpr bool use_codebook = true;
+
+__device__ static inline uint64_t decode8weights(
+    uint16_t weight_compressed,
+    const int64_t *__restrict__ codebook_abs
+) {
+
+    uint8_t bits_sign = weight_compressed & 0xff;
+    uint8_t parity = __popc(bits_sign) & 1;
+    uint8_t sign_vec = bits_sign ^ parity;
+    uint8_t bits_abs = (weight_compressed >> 8);
+    int64_t packed = codebook_abs[bits_abs];
+
+    uint64_t decoded_sign = sign_vec * 0x8040201008040201ll;
+    decoded_sign &= 0x8080808080808080;
+    decoded_sign >>= 7;
+    decoded_sign *= 255 - 3;
+    packed ^= decoded_sign;
+    packed |= 0x0101010101010101;
+    packed -= parity * 0x0202020202020202;
+
+    return packed;
+}
+
+__device__ static inline uint32_t decode8weights(
+    uint16_t weight_compressed,
+    const int64_t *__restrict__ codebook_abs,
+    int idx
+) {
+    uint8_t bits_sign = weight_compressed & 0xff; //__brev(weight_compressed) >> 24;
+    const uint32_t magic_nums[2] = {0x08040201ll, 0x80402010ll};
+    uint8_t parity = __popc(bits_sign) & 1;
+    uint8_t sign_vec = bits_sign ^ parity; // (parity << 7);
+    uint16_t bits_abs = (weight_compressed >> 8);
+    uint32_t packed = ((uint32_t*)codebook_abs)[(bits_abs << 1) + idx];
+    uint32_t magic_num = magic_nums[idx];
+    uint32_t decoded_sign = sign_vec * magic_num;
+    decoded_sign &= 0x80808080;
+    decoded_sign >>= 7;
+    decoded_sign *= 255 - 3;
+    packed ^= decoded_sign;
+    packed |= 0x01010101;
+    packed -= parity * 0x02020202;
+    return packed;
+};
+
+template <int KTilesPerIteration>
+static __device__ void load(
+    const void* __restrict__ B,
+    const uint64_t* __restrict__ CB,
+    int32_t n,
+    int32_t k,
+    int32_t nTiles,
+    int32_t nTile,
+    int32_t kTiles,
+    int32_t kTileStart,
+    int32_t laneId,
+    f16x2x2_u32 b[KTilesPerIteration]) {
+  auto Bptr = (const uint16_t*) B;
+  #pragma unroll
+  for (int i = 0; i < KTilesPerIteration; ++i) {
+       const int row = nTile * kNTileSize + laneId / 4;
+       const int col = (kTileStart + i) * kKTileSize / 8 + laneId % 4 / 2;
+       uint32_t decoded = decode8weights(Bptr[row * k/8 + col], (const int64_t*)CB, laneId & 1);
+       half2 unpacked[2];
+       uint32_t lower_half = decoded & 0x00ff00ff;
+       lower_half = (lower_half ^ 0x5c805c80);
+       memcpy(unpacked, &lower_half, sizeof(uint32_t));
+       uint32_t upper_half = (decoded & 0xff00ff00) >> 8;
+       upper_half = (upper_half ^ 0x5c805c80);
+       memcpy(unpacked + 1, &upper_half, sizeof(uint32_t));
+
+       const half adjust_ = __float2half_rn(-288.0f);
+       const half2 adjust = __halves2half2(adjust_, adjust_);
+       unpacked[0] = __hadd2(unpacked[0], adjust);
+       unpacked[1] = __hadd2(unpacked[1], adjust);
+       *(reinterpret_cast<uint64_t*>(b[i].vals)) = *(reinterpret_cast<uint64_t*>(unpacked));
+       //*((half*)(b[i].vals)) = unpacked[0];
+       //*((half*)(b[i].vals) + 1) = unpacked[0].y;
+       //*((half*)(b[i].vals) + 2) = unpacked[1].x;
+       //*((half*)(b[i].vals) + 3) = unpacked[1].y;
+  }
+}
+};
+
+
+template <
+    typename ALayout,
+    typename BLayout,
+    typename CLayout,
+    int Warps,
+    int KTilesPerIteration>
+__global__
+__launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
+    // Data for the A matrix, loaded as per ALayout
+    const half* __restrict__ A,
+    const void* __restrict__ B,
+    const uint64_t* __restrict__ CB,
+
+    // Output data for the C matrix, stored as per CLayout
+    half* __restrict__ C,
+
+    // The size of the matrix multiplication
+    int32_t m,
+    int32_t n,
+    int32_t k,
+
+    // The size of the matrix multiplication, in multiples of our TC tile size
+    int32_t mTiles,
+    int32_t nTiles,
+    int32_t kTiles) {
+  __shared__ uint64_t CB_[256];
+  if (BLayout::use_codebook) {
+    CB_[threadIdx.x + threadIdx.y * 32] = CB[threadIdx.x + threadIdx.y * 32];
+    __syncthreads();
+  }
+
+  auto warpId = threadIdx.y;
+  auto laneId = threadIdx.x;
+
+  int32_t mTile = blockIdx.z;
+  int32_t nTile = blockIdx.y;
+
+  float4 c{0.0f, 0.0f, 0.0f, 0.0f};
+
+ // First, handle whole multiples of KTilesPerIteration
+  auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
+
+  // Each warp handles a set of KTilesPerIteration under the above limit
+  for (int32_t kTileBase = warpId * KTilesPerIteration; kTileBase < kTilesLimit; kTileBase += Warps * KTilesPerIteration) {
+    //
+    // Load data from A
+    //
+    f16x2x4_u32 a[KTilesPerIteration];
+    ALayout::template load<KTilesPerIteration>(
+        A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
+
+    //
+    // Load data from B and de-quantize as needed
+    //
+    f16x2x2_u32 b[KTilesPerIteration];
+    BLayout::template load<KTilesPerIteration>(
+        B, CB_, n, k, nTiles, nTile, kTiles, kTileBase, laneId, b);
+
+    // Now, perform the matrix multiplication
+    //
+    #pragma unroll
+    for (int i = 0; i < KTilesPerIteration / 2; ++i) {
+      float4 cTmp[2];
+
+      #pragma unroll
+      for (int k = 0; k < 2; ++k) {
+        cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
+      }
+
+      #pragma unroll
+      for (int k = 0; k < 2; ++k) {
+        asm volatile(
+              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+              "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
+              : "=f"(cTmp[k].x),
+                "=f"(cTmp[k].y),
+                "=f"(cTmp[k].z),
+                "=f"(cTmp[k].w)
+              : "r"(a[i * 2 + k].vals[0]),
+                "r"(a[i * 2 + k].vals[1]),
+                "r"(a[i * 2 + k].vals[2]),
+                "r"(a[i * 2 + k].vals[3]),
+                "r"(b[i * 2 + k].vals[0]),
+                "r"(b[i * 2 + k].vals[1]),
+                "f"(cTmp[k].x),
+                "f"(cTmp[k].y),
+                "f"(cTmp[k].z),
+                "f"(cTmp[k].w));
+      }
+      #pragma unroll
+      for (int k = 0; k < 2; ++k) {
+        c.x += cTmp[k].x;
+        c.y += cTmp[k].y;
+        c.z += cTmp[k].z;
+        c.w += cTmp[k].w;
+      }
+    }
+
+  } // for all tiles under kTilesLimit
+
+
+  auto kTileBaseRemaining = kTilesLimit + warpId;
+
+  // If we have any remainder k-tiles, some warps will handle them, processing
+  // kInnerKTiles k-tiles at a time
+  if (kTileBaseRemaining < kTiles) {
+    f16x2x4_u32 a;
+    ALayout::template load<1>(
+        A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, &a);
+
+    f16x2x2_u32 b;
+    BLayout::template load<1>(
+        B, CB, n, k, nTiles, nTile, kTiles, kTileBaseRemaining, laneId, &b);
+
+    asm volatile(
+              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+              "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
+              : "=f"(c.x),
+                "=f"(c.y),
+                "=f"(c.z),
+                "=f"(c.w)
+              : "r"(a.vals[0]),
+                "r"(a.vals[1]),
+                "r"(a.vals[2]),
+                "r"(a.vals[3]),
+                "r"(b.vals[0]),
+                "r"(b.vals[1]),
+                "f"(c.x),
+                "f"(c.y),
+                "f"(c.z),
+                "f"(c.w));
+  }
+  // Reduce independent k-tiles (same m/n) across warps
+  __shared__ float4 smem_sum[Warps][kWarpSize];
+
+  smem_sum[warpId][laneId] = c;
+
+  __syncthreads();
+
+  if (warpId == 0) {
+    float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
+
+    // Reduce across the block in the first warp
+    for (int i = 0; i < Warps; ++i) {
+      float4 v = smem_sum[i][laneId];
+      sum_f32.x += v.x;
+      sum_f32.y += v.y;
+      sum_f32.z += v.z;
+      sum_f32.w += v.w;
+    }
+
+    // Write the reduced result (in the first warp) into the output
+    CLayout::store(
+        C,
+        m,
+        n,
+        mTiles,
+        mTile,
+        // n for C output becomes k for A input, so for m16n8k16,
+        // we need to halve the tiles
+        nTiles / 2,
+        nTile,
+        laneId,
+        sum_f32);
+  }
+}
+
+at::Tensor d4_mm_origorder(
+    const at::Tensor& A,
+    const at::Tensor& B,
+    const at::Tensor& CB) {
+  c10::cuda::CUDAGuard g(A.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int Warps = 8;
+
+  // row major layout
+  auto m = A.size(0);
+  auto mTiles = divUp(m, kMTileSize);
+
+  // tensor core layout
+  auto n = B.size(0);
+  auto nTiles = divUp(n, kNTileSize);
+
+  // row major layout
+  auto k = A.size(1);
+  auto kTiles = divUp(k, kKTileSize);
+
+  // Output is a standard row-major matrix
+  auto C_final = at::empty(
+      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));
+
+  auto grid = dim3(1, nTiles, mTiles);
+  auto block = dim3(kWarpSize, Warps);
+  auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_D4, ALayout_RM, 8, 8>;
+
+  kernel<<<grid, block, 0, stream>>>(
+      (const half*)A.data_ptr(),
+      (const void*)B.data_ptr(),
+      (const uint64_t*)CB.data_ptr(),
+      (half*)C_final.data_ptr(),
+      m,
+      n,
+      k,
+      mTiles,
+      nTiles,
+      kTiles);
+
+  return C_final;
+}
+
+at::Tensor e8p_mm_origorder(
+    const at::Tensor& A,
+    const at::Tensor& B,
+    const at::Tensor& CB) {
+  c10::cuda::CUDAGuard g(A.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int Warps = 8;
+
+  // row major layout
+  auto m = A.size(0);
+  auto mTiles = divUp(m, kMTileSize);
+
+  // tensor core layout
+  auto n = B.size(0);
+  auto nTiles = divUp(n, kNTileSize);
+
+  // row major layout
+  auto k = A.size(1);
+  auto kTiles = divUp(k, kKTileSize);
+
+  // Output is a standard row-major matrix
+  auto C_final = at::empty(
+      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));
+
+  auto grid = dim3(1, nTiles, mTiles);
+  auto block = dim3(kWarpSize, Warps);
+  auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_E8, ALayout_RM, 8, 8>;
+  kernel<<<grid, block, 0, stream>>>(
+      (const half*)A.data_ptr(),
+      (const void*)B.data_ptr(),
+      (const uint64_t*)CB.data_ptr(),
+      (half*)C_final.data_ptr(),
+      m,
+      n,
+      k,
+      mTiles,
+      nTiles,
+      kTiles);
+
+  return C_final;
+}
+
+at::Tensor hi_mm_origorder(
+    const at::Tensor& A,
+    const at::Tensor& B) {
+  c10::cuda::CUDAGuard g(A.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int Warps = 8;
+
+  // row major layout
+  auto m = A.size(0);
+  auto mTiles = divUp(m, kMTileSize);
+
+  // tensor core layout
+  auto n = B.size(0);
+  auto nTiles = divUp(n, kNTileSize);
+
+  // row major layout
+  auto k = A.size(1);
+  auto kTiles = divUp(k, kKTileSize);
+
+  // Output is a standard row-major matrix
+  auto C_final = at::empty(
+      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));
+
+  auto grid = dim3(1, nTiles, mTiles);
+  auto block = dim3(kWarpSize, Warps);
+  auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_HI, ALayout_RM, 8, 8>;
+  kernel<<<grid, block, 0, stream>>>(
+      (const half*)A.data_ptr(),
+      (const void*)B.data_ptr(),
+      nullptr,
+      (half*)C_final.data_ptr(),
+      m,
+      n,
+      k,
+      mTiles,
+      nTiles,
+      kTiles);
+
+  return C_final;
+}
+
+#define DECOMPRESS_D4_BLOCK_SIZE 256
+
+__global__ void cuda_decompress_d4_origorder_kernel(
+    const uint8_t* __restrict__ YIs,	  // m x (n/4)
+    const c10::Half* __restrict__ CB,           // 256 x 4
+    c10::Half* __restrict__ Y             // m x n
+) {
+  const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x;
+
+  for(long r = 0; r < 4; r++) {
+    uint8_t yidx = ((uint8_t*)YIs)[i*4 + r];
+    ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255];
+  }
+}
+
+
+void decompress_d4_origorder(
+    torch::Tensor YIs,      // m x (n/4)
+    torch::Tensor CB,       // 256 x 4
+    torch::Tensor Y         // m x n
+) {
+  size_t m = Y.sizes()[0];
+  size_t n = Y.sizes()[1];
+
+  assert(YIs.is_contiguous());
+  assert(CB.is_contiguous());
+  assert(Y.is_contiguous());
+
+  assert(YIs.sizes()[0] == m);
+  assert(YIs.sizes()[1] * 4 == n);
+  assert(CB.sizes()[0] == 256);
+
+  const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE);
+  const dim3 blocks(m*n/(16*DECOMPRESS_D4_BLOCK_SIZE));
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  cuda_decompress_d4_origorder_kernel<<<blocks, threads, 0, stream>>>(
+    YIs.data_ptr<uint8_t>(),
+    CB.data_ptr<c10::Half>(),
+    Y.data_ptr<c10::Half>()
+  );
+}
+
+#define DECOMPRESS_E8P_BLOCK_SIZE 256
+
+__global__ void cuda_decompress_e8p_origorder_kernel(
+    const int16_t* __restrict__ YIs,	  // m x (n/8)
+    const int64_t* __restrict__ CB, // 256 x 8
+    c10::Half* __restrict__ Y             // m x n
+) {
+  const long i = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x;
+  uint16_t yidx = ((uint16_t*)YIs)[i];
+  uint64_t decoded =  BLayout_E8::decode8weights(yidx, CB);
+
+  half2 unpacked[2][2];
+  uint64_t lower_half = decoded & 0x00ff00ff00ff00ff;
+  lower_half = (lower_half ^ 0x5c805c805c805c80);
+  memcpy(unpacked[0], &lower_half, sizeof(uint64_t));
+  uint64_t upper_half = (decoded & 0xff00ff00ff00ff00) >> 8;
+  upper_half = (upper_half ^ 0x5c805c805c805c80);
+  memcpy(unpacked[1], &upper_half, sizeof(uint64_t));
+
+  const half adjust_ = __float2half_rn(-288.0f);
+  const half2 adjust = __halves2half2(adjust_, adjust_);
+
+  ((__half2*)Y)[i*4] = __hadd2(unpacked[0][0], adjust); // 01
+  ((__half2*)Y)[i*4+2] = __hadd2(unpacked[0][1], adjust); // 45
+  ((__half2*)Y)[i*4+1] = __hadd2(unpacked[1][0], adjust); // 23
+  ((__half2*)Y)[i*4+3] = __hadd2(unpacked[1][1], adjust); // 67
+}
+
+
+void decompress_e8p_origorder(
+    torch::Tensor YIs,      // m x (n/8)
+    torch::Tensor CB,       // 256 x 8
+    torch::Tensor &Y         // m x n
+) {
+  size_t m = Y.sizes()[0];
+  size_t n = Y.sizes()[1];
+
+  assert(YIs.is_contiguous());
+  assert(CB.is_contiguous());
+  assert(Y.is_contiguous());
+
+  assert(YIs.sizes()[0] == m);
+  assert(YIs.sizes()[1] * 8 == n);
+  assert(CB.sizes()[0] == 256);
+
+  const dim3 threads(DECOMPRESS_E8P_BLOCK_SIZE);
+  const dim3 blocks(m*n/(8*DECOMPRESS_E8P_BLOCK_SIZE));
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  cuda_decompress_e8p_origorder_kernel<<<blocks, threads, 0, stream>>>(
+    YIs.data_ptr<int16_t>(),
+    CB.data_ptr<int64_t>(),
+    Y.data_ptr<c10::Half>()
+  );
+}
+
+#define DECOMPRESS_HI_BLOCK_SIZE 256
+
+__global__ void cuda_decompress_hi_origorder_kernel(
+    const uint32_t* __restrict__ YIs,	  // m x (n/8)
+    c10::Half* __restrict__ Y             // m x n
+) {
+  const long i = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x;
+  uint32_t qa = YIs[i];
+
+  const uint32_t c0 = 0x64086408;
+  const half y16_ = __float2half_rn(1.0f / 16.0f);
+  const half2 y16 = __halves2half2(y16_, y16_);
+  const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
+  const half2 z16 = __halves2half2(z16_, z16_);
+
+
+  uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
+  uint32_t q1 = ((qa & 0x00f000f0)| c0);
+  qa >>= 8;
+  uint32_t q2 = (((qa & 0x000f000f) << 4) | c0);
+  uint32_t q3 = ((qa & 0x00f000f0) | c0);
+  ((__half2*)Y)[i*4] = __hfma2(*((half2*)(&q0)), y16, z16);
+  ((__half2*)Y)[i*4+1] = __hfma2(*((half2*)(&q1)), y16, z16);
+  ((__half2*)Y)[i*4+2] = __hfma2(*((half2*)(&q2)), y16, z16);
+  ((__half2*)Y)[i*4+3] = __hfma2(*((half2*)(&q3)), y16, z16);
+}
+
+void decompress_hi_origorder(
+    torch::Tensor YIs,      // m x (n/8)
+    torch::Tensor Y         // m x n
+){
+  size_t m = Y.sizes()[0];
+  size_t n = Y.sizes()[1];
+
+  assert(YIs.is_contiguous());
+  assert(Y.is_contiguous());
+
+  assert(YIs.sizes()[0] == m);
+  assert(YIs.sizes()[1] * 8 == n);
+
+  const dim3 threads(DECOMPRESS_HI_BLOCK_SIZE);
+  const dim3 blocks(m*n/(8*DECOMPRESS_HI_BLOCK_SIZE));
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  cuda_decompress_hi_origorder_kernel<<<blocks, threads, 0, stream>>>(
+    (uint32_t*)YIs.data_ptr<int32_t>(),
+    Y.data_ptr<c10::Half>()
+  );
+}

+ 3 - 0
requirements.txt

@@ -18,3 +18,6 @@ aioprometheus[starlette] # for prometheus metrics
 triton >= 2.1.0
 lark == 1.1.8 # for grammars
 pynvml == 11.5.0
+gguf # for gguf
+scipy # for quip
+fast-hadamard-transform # for quip

+ 4 - 0
setup.py

@@ -272,6 +272,7 @@ aphrodite_extension_sources = [
     "kernels/activation_kernels.cu",
     "kernels/layernorm_kernels.cu",
     "kernels/quantization/squeezellm/quant_cuda_kernel.cu",
+    "kernels/quantization/gguf/gguf_kernel.cu",
     "kernels/quantization/gptq/q_gemm.cu",
     "kernels/cuda_utils_kernels.cu",
     "kernels/pybind.cpp",
@@ -279,6 +280,8 @@ aphrodite_extension_sources = [
 
 if _is_cuda():
     aphrodite_extension_sources.append("kernels/quantization/awq/gemm_kernels.cu")
+    aphrodite_extension_sources.append("kernels/quantization/quip/origin_order.cu")
+    aphrodite_extension_sources.append("kernels/quantization/marlin/marlin_cuda_kernel.cu")
     aphrodite_extension_sources.append("kernels/all_reduce/custom_all_reduce.cu")
 
 aphrodite_extension = CUDAExtension(
@@ -389,6 +392,7 @@ setuptools.setup(
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
     package_data={"aphrodite-engine": ["aphrodite/endpoints/kobold/klite.embd",
+                                       "aphrodite/modeling/layers/quantization/hadamard.safetensors",
                                        "py.typed"]},
     include_package_data=True,
 )

+ 4 - 2
tests/benchmarks/throughput.py

@@ -203,11 +203,13 @@ if __name__ == "__main__":
                         type=str,
                         required=True,
                         help="Path to the dataset.")
-    parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument("--model",
+                        type=str,
+                        default="EleutherAI/pythia-70m-deduped")
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument("--quantization",
                         "-q",
-                        choices=["awq", None],
+                        choices=["awq", "gguf", "gptq", "squeezellm", None],
                         default=None)
     parser.add_argument("--gpu-memory-utilization", type=float, default=0.88)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)