feat: better marlin kernels (#285)

* re-integrate marlin with updated kernels

Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>

* add conversion script

* formatting

---------

Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>
AlpinDale 1 year ago
parent
commit
72229a94da

+ 1 - 0
.pylintrc

@@ -54,6 +54,7 @@ disable=abstract-method,
         consider-using-with,
         consider-using-in,
         apply-builtin,
+        unnecessary-comprehension,
         arguments-differ,
         attribute-defined-outside-init,
         backtick,

+ 13 - 4
aphrodite/common/config.py

@@ -150,7 +150,9 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gguf", "gptq", "quip", "squeezellm"]
+        supported_quantization = [
+            "awq", "gguf", "gptq", "quip", "squeezellm", "marlin"
+        ]
         rocm_not_supported_quantization = ["awq", "quip"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -166,6 +168,11 @@ class ModelConfig:
         hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
+            # If the GPTQ model is serialized in marlin format, use marlin.
+            if (hf_quant_method == "gptq"
+                    and "is_marlin_format" in hf_quant_config
+                    and hf_quant_config["is_marlin_format"]):
+                hf_quant_method = "marlin"
             if self.quantization is None:
                 self.quantization = hf_quant_method
             elif self.quantization != hf_quant_method:
@@ -185,9 +192,11 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     "supported in ROCm.")
-            logger.warning(f"{self.quantization} quantization is not fully "
-                           "optimized yet. The speed can be slower than "
-                           "non-quantized models.")
+            if self.quantization != "marlin":
+                logger.warning(
+                    f"{self.quantization} quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.")
 
     def _verify_cuda_graph(self) -> None:
         if self.max_context_len_to_capture is None:
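For context, a minimal sketch of the new detection path: a GPTQ checkpoint whose `quantization_config` (typically `quantize_config.json`) carries `is_marlin_format` is promoted to the `marlin` method. The field values below are illustrative, not taken from a real checkpoint.

```py
hf_quant_config = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "is_marlin_format": True,
}

hf_quant_method = str(hf_quant_config["quant_method"]).lower()
if hf_quant_method == "gptq" and hf_quant_config.get("is_marlin_format"):
    hf_quant_method = "marlin"
print(hf_quant_method)  # -> marlin
```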

+ 14 - 12
aphrodite/engine/args_tools.py

@@ -204,18 +204,20 @@ class EngineArgs:
                             action='store_true',
                             help='disable logging statistics')
         # Quantization settings.
-        parser.add_argument(
-            '--quantization',
-            '-q',
-            type=str,
-            choices=['awq', 'gguf', 'gptq', 'quip', 'squeezellm', None],
-            default=EngineArgs.quantization,
-            help='Method used to quantize the weights. If '
-            'None, we first check the `quantization_config` '
-            'attribute in the model config file. If that is '
-            'None, we assume the model weights are not '
-            'quantized and use `dtype` to determine the data '
-            'type of the weights.')
+        parser.add_argument('--quantization',
+                            '-q',
+                            type=str,
+                            choices=[
+                                'awq', 'gguf', 'gptq', 'quip', 'squeezellm',
+                                'marlin', None
+                            ],
+                            default=EngineArgs.quantization,
+                            help='Method used to quantize the weights. If '
+                            'None, we first check the `quantization_config` '
+                            'attribute in the model config file. If that is '
+                            'None, we assume the model weights are not '
+                            'quantized and use `dtype` to determine the data '
+                            'type of the weights.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '

+ 25 - 0
aphrodite/modeling/layers/linear.py

@@ -17,6 +17,14 @@ from aphrodite.common.logger import init_logger
 logger = init_logger(__name__)
 
 
+def adjust_marlin_shard(param, shard_size, shard_offset):
+    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+    if marlin_tile_size is None:
+        return shard_size, shard_offset
+
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
 class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
@@ -288,6 +296,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -305,6 +318,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
@@ -405,6 +422,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -429,6 +450,10 @@ class QKVParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
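Marlin stores `qweight` in 16x16 tiles, so a shard computed in packed-int32 units must additionally be scaled by `marlin_tile_size`; that is all `adjust_marlin_shard` does. A small worked example with assumed sizes:

```py
# Illustrative values for one output shard of a merged/QKV linear layer.
pack_factor, marlin_tile_size = 8, 16   # 4-bit weights in int32, 16x16 tiles
shard_size, shard_offset = 4096, 4096   # in output features

shard_size //= pack_factor              # -> 512 packed int32 columns
shard_offset //= pack_factor            # -> 512
shard_size *= marlin_tile_size          # -> 8192, matching qweight dim 1
shard_offset *= marlin_tile_size        # -> 8192
print(shard_size, shard_offset)
```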

+ 2 - 0
aphrodite/modeling/layers/quantization/__init__.py

@@ -6,6 +6,7 @@ from aphrodite.modeling.layers.quantization.gguf import GGUFConfig
 from aphrodite.modeling.layers.quantization.gptq import GPTQConfig
 from aphrodite.modeling.layers.quantization.quip import QuipConfig
 from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
+from aphrodite.modeling.layers.quantization.marlin import MarlinConfig
 
 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
@@ -13,6 +14,7 @@ _QUANTIZATION_CONFIG_REGISTRY = {
     "gptq": GPTQConfig,
     "quip": QuipConfig,
     "squeezellm": SqueezeLLMConfig,
+    "marlin": MarlinConfig,
 }
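A minimal usage sketch of the extended registry (this imports the private `_QUANTIZATION_CONFIG_REGISTRY` directly for illustration; a higher-level accessor may exist):

```py
from aphrodite.modeling.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

cfg = _QUANTIZATION_CONFIG_REGISTRY["marlin"].from_config({"group_size": 128})
linear_method = cfg.get_linear_method()  # MarlinLinearMethod
```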
 
 

+ 216 - 0
aphrodite/modeling/layers/quantization/marlin.py

@@ -0,0 +1,216 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from aphrodite._C import ops
+from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
+from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+    ) -> None:
+        # Group size for the quantization.
+        self.group_size = group_size
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group sizes 128 and -1 (channelwise) are supported "
+                f"for Marlin, but got group_size of {self.group_size}")
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return f"MarlinConfig(group_size={self.group_size})"
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # TODO: confirm the minimum compute capability required by the kernels.
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        return cls(group_size)
+
+    def get_linear_method(self) -> "MarlinLinearMethod":
+        return MarlinLinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def merge_weight(self) -> bool:
+        return False
+
+    def rope_style(self) -> Optional[bool]:
+        return None
+
+
+class MarlinLinearMethod(LinearMethodBase):
+    """Linear method for Marlin.
+
+    Args:
+        quant_config: The Marlin quantization config.
+    """
+
+    def __init__(self, quant_config: MarlinConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        input_size_per_partition: int,
+        output_size_per_partition: int,
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        del output_size  # Unused.
+
+        if params_dtype != torch.float16:
+            raise ValueError(
+                f"The params dtype must be float16, but got {params_dtype}")
+
+        # Validate output_size_per_partition
+        if output_size_per_partition % self.quant_config.min_n_threads != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
+            )
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
+            )
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.min_k_threads != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
+            )
+        if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
+            )
+
+        # Check that we have at least 4 tiles horizontally in the shard
+        num_tiles_per_perm = self.quant_config.perm_len // (
+            self.quant_config.tile_size**2)
+        if output_size_per_partition % num_tiles_per_perm != 0:
+            raise ValueError(
+                "Each permutation group must reside on the same GPU")
+
+        # Quantized 4Bit weights packed into Int32.
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.quant_config.tile_size,
+                output_size_per_partition * self.quant_config.tile_size //
+                self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": self.quant_config.pack_factor,
+                "marlin_tile_size": self.quant_config.tile_size,
+            },
+        )
+
+        # Determine if channelwise or not
+        input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
+
+        scales = Parameter(
+            torch.empty(
+                input_groups,
+                output_size_per_partition,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            scales,
+            {
+                "input_dim": None if input_groups == 1 else 0,
+                "output_dim": 1,
+            },
+        )
+
+        # Allocate workspace (Used for internal locking mechanism)
+        max_workspace_size = (
+            output_size_per_partition //
+            self.quant_config.min_n_threads) * self.quant_config.max_parallel
+        workspace = Parameter(torch.zeros(max_workspace_size,
+                                          device="cuda",
+                                          dtype=torch.int),
+                              requires_grad=False)
+
+        return {
+            "B": qweight,
+            "s": scales,
+            "workspace": workspace,
+        }
+
+    def apply_weights(
+        self,
+        weights: Dict[str, Any],
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = weights["B"]
+        scales = weights["s"]
+        workspace = weights["workspace"]
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
+                                    size_n, size_k)
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output
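To make the allocation above concrete, here is the shape arithmetic for a hypothetical 4096x4096 layer with `group_size=128`, mirroring `create_weights`:

```py
in_f, out_f = 4096, 4096
tile_size, pack_factor, group_size = 16, 8, 128
min_n_threads, max_parallel = 64, 16

qweight_shape = (in_f // tile_size, out_f * tile_size // pack_factor)  # (256, 8192)
scales_shape = (in_f // group_size, out_f)                             # (32, 4096)
workspace_len = (out_f // min_n_threads) * max_parallel                # 1024
print(qweight_shape, scales_shape, workspace_len)
```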

+ 66 - 0
examples/marlin/README.md

@@ -0,0 +1,66 @@
+## GPTQ Conversion to Marlin
+
+### Acquiring a compatible GPTQ model
+
+First, you will need a GPTQ model that satisfies all of the following conditions:
+- `group_size=-1` OR `128`
+- `bits=4`
+- `desc_act=False`
+
+If your model does not meet the requirements above, then run the following script to convert an FP16 model to the appropriate GPTQ format:
+
+```py
+from transformers import AutoTokenizer, TextGenerationPipeline
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+
+pretrained_model_dir = "mistralai/Mistral-7B-Instruct-v0.2"
+quantized_model_dir = "/path/to/output"
+
+
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+examples = [
+    tokenizer(
+        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
+    )
+]
+
+quantize_config = BaseQuantizeConfig(
+    bits=4,
+    group_size=128,
+    desc_act=False,
+)
+
+model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
+model.quantize(examples)
+
+model.save_quantized(quantized_model_dir, use_safetensors=True)
+model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
+```
+
+Replace the `pretrained_model_dir` and `quantized_model_dir` with the appropriate paths to your base model and output directory. Save the script, and run it like this:
+
+```sh
+CUDA_VISIBLE_DEVICES=0 python quantize.py
+```
+You may need to install the AutoGPTQ library via `pip install auto-gptq`.
+
+
+Once you have your compatible GPTQ model, follow the steps below to convert it to Marlin format.
+
+### Converting GPTQ models to Marlin
+
+You will need to clone and install the Marlin repository:
+
+```sh
+git clone https://github.com/IST-DASLab/marlin && cd marlin
+
+pip install -e .
+```
+
+Then simply run the following in this directory:
+
+```sh
+python convert.py --model-id /path/to/gptq/model --save-path /path/to/output/marlin
+```
+
+That should be all you need to do. Then launch Aphrodite with `--model` pointing at the Marlin checkpoint (see the example below). Happy prompting.
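For example, a server launch might look like the following. The entry-point module name is an assumption and may differ between Aphrodite versions; `--quantization marlin` matches the flag added in this PR and can also be omitted, since the format is auto-detected from the checkpoint config.

```sh
# Hypothetical launch command; adjust the entry point for your install.
python -m aphrodite.endpoints.openai.api_server \
    --model /path/to/output/marlin \
    --quantization marlin
```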

+ 161 - 0
examples/marlin/convert.py

@@ -0,0 +1,161 @@
+import argparse
+import copy
+import gc
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
+from marlin import Layer as MarlinLayer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model-id", type=str)
+parser.add_argument("--save-path", type=str)
+parser.add_argument("--do-generation", action="store_true")
+
+def _validate_compatibility(model):
+    if not hasattr(model.config, "quantization_config"):
+        raise ValueError("Must be a quantized model to convert to Marlin Format")
+    quantization_config = model.config.quantization_config
+    if quantization_config.quant_method != "gptq":
+        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
+    if quantization_config.bits != 4:
+        raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
+    if quantization_config.group_size != 128:
+        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
+    if not quantization_config.sym:
+        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
+    if quantization_config.desc_act:
+        raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")
+
+@torch.no_grad()
+def unpack_4bit_to_32bit_signed(qweight, qzeros):
+    # Unpack 4-bit values and interpret them as signed integers
+    unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
+    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)
+
+    for row in range(unpacked_weights.shape[0]):
+        i = row % 8
+        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
+
+    for col in range(unpacked_zeros.shape[1]):
+        i = col % 8
+        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
+
+    return unpacked_weights, unpacked_zeros + 1
+
+@torch.no_grad()
+def dequantize_weight(layer):
+    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
+    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
+    group_size = unpacked_qweight.shape[0] // scales.shape[0]
+    scales = scales.repeat_interleave(group_size, dim=0)
+    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
+    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
+
+    return unpacked_qweight.T
+
+@torch.no_grad()
+def convert_model(model, verbose=True):
+    for name, module in model.named_modules():
+        if not isinstance(module, QuantLinear):
+            continue
+
+        if verbose:
+            print(f"--- Converting Module: {name}")
+        parent_name = ".".join(name.split(".")[:-1])
+        layer_name = name[len(parent_name) + 1:]
+
+        # Dequantize the weight.
+        dequantized_weight = dequantize_weight(module).to(torch.float16)
+        linear_module = torch.nn.Linear(
+            in_features=dequantized_weight.shape[1],
+            out_features=dequantized_weight.shape[0],
+            bias=False,
+            dtype=torch.float16,
+            device="cuda")
+        linear_module.weight.data.copy_(dequantized_weight)
+
+        # Create new linear method and copy to model.
+        new_module = MarlinLayer(
+            infeatures=linear_module.in_features,
+            outfeatures=linear_module.out_features,
+            groupsize=model.config.quantization_config.group_size)
+        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))
+
+        # Save to parent.
+        parent_module = model.get_submodule(parent_name)
+        setattr(parent_module, layer_name, new_module)
+
+        # Free cuda memory.
+        del dequantized_weight, module
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    return model
+
+@torch.no_grad()
+def dequantize_model(model, verbose=True):
+    for name, module in model.named_modules():
+        if not isinstance(module, QuantLinear):
+            continue
+
+        if verbose:
+            print(f"--- Dequantizing Module: {name}")
+        parent_name = ".".join(name.split(".")[:-1])
+        layer_name = name[len(parent_name) + 1:]
+
+        # Dequantize the weight.
+        dequantized_weight = dequantize_weight(module)
+        dequantized_weight_cpu = dequantized_weight.to("cpu")
+
+        # Create new linear method and copy to model.
+        new_module = torch.nn.Linear(
+            in_features=dequantized_weight_cpu.shape[1],
+            out_features=dequantized_weight_cpu.shape[0],
+            bias=False,
+            dtype=torch.float16)
+        new_module.weight.data.copy_(dequantized_weight_cpu)
+        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))
+
+        # Save to parent.
+        parent_module = model.get_submodule(parent_name)
+        setattr(parent_module, layer_name, new_module)
+
+        # Free cuda memory.
+        del dequantized_weight, dequantized_weight_cpu, module
+        torch.cuda.empty_cache()
+
+    return model
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    model_id = args.model_id
+    save_path = args.save_path
+    do_generation = args.do_generation
+
+    print("Loading gptq model...")
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    # Validate that this model is compatible with Marlin.
+    print("Validating compatibility...")
+    _validate_compatibility(model)
+
+    # Dequantize the Model.
+    print("Converting model...")
+    model = convert_model(model).to("cpu")
+
+    # Save after updating quantization config.
+    print("Saving marlin model...")
+    model.config.quantization_config = {
+        "group_size": model.config.quantization_config.group_size,
+        "quant_method": "marlin"
+    }
+    model.save_pretrained(save_path)
+    tokenizer.save_pretrained(save_path)
+
+    if do_generation:
+        print("Generating sample text...")
+        model.to("cuda")
+        prompt = "My favorite song is"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
+        print(tokenizer.batch_decode(outputs)[0])
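A tiny illustration (values invented) of the bit manipulation performed per column in `unpack_4bit_to_32bit_signed`: eight 4-bit values are packed LSB-first into each int32 and recovered by shifting and masking.

```py
import torch

packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # nibbles 0..7, LSB first
unpacked = [int((packed[0, 0] >> (4 * i)) & 0xF) for i in range(8)]
print(unpacked)  # -> [0, 1, 2, 3, 4, 5, 6, 7]
```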

+ 8 - 6
kernels/ops.h

@@ -85,12 +85,14 @@ torch::Tensor awq_dequantize(
     int thx,
     int thy);
 
-void marlin_gemm(
-  const torch::Tensor& input,
-  const torch::Tensor& weights,
-        torch::Tensor& output,
-  const torch::Tensor& scales,
-        torch::Tensor& workspace);
+torch::Tensor marlin_gemm(
+    torch::Tensor& a, 
+    torch::Tensor& b_q_weight,
+    torch::Tensor& b_scales, 
+    torch::Tensor& workspace,
+    int64_t size_m, 
+    int64_t size_n, 
+    int64_t size_k);
 
 at::Tensor e8p_mm_origorder(
     const at::Tensor& A,

+ 210 - 0
kernels/quantization/marlin/LICENSE

@@ -0,0 +1,210 @@
+
+Contains code from https://github.com/IST-DASLab/marlin
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright {yyyy} {name of copyright owner}
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+------------------------------------------------------------------------------------
+
+This product bundles various third-party components under other open source licenses.
+This section summarizes those components and their licenses. See licenses/
+for text of these licenses.

+ 1131 - 856
kernels/quantization/marlin/marlin_cuda_kernel.cu

@@ -1,5 +1,6 @@
 /*
- * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,858 +15,1132 @@
  * limitations under the License.
  */
 
-
-#ifndef MARLIN_CUDA_KERNEL_CUH
-#define MARLIN_CUDA_KERNEL_CUH
-
-#include <torch/extension.h>
-#include <c10/cuda/CUDAStream.h>
-#include <cuda.h>
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-
-namespace aphrodite {
-namespace marlin {
-
-constexpr int ceildiv(int a, int b) {
-  return (a + b - 1) / b;
-}
-
-// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
-// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
-// extensively use `#pragma unroll` throughout the kernel code to guarantee this.
-template <typename T, int n>
-struct Vec {
-  T elems[n];
-  __device__ T& operator[](int i) {
-    return elems[i];
-  }
-};
-
-using I4 = Vec<int, 4>;
-
-// Matrix fragments for tensor core instructions; their precise layout is documented here:
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
-using FragA = Vec<half2, 4>;
-using FragB = Vec<half2, 2>;
-using FragC = Vec<float, 4>;
-using FragS = Vec<half2, 1>; // quantization scales
-
-// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that
-// are not multiples of 16.
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-    "{\n"
-    "   .reg .pred p;\n"
-    "   setp.ne.b32 p, %0, 0;\n"
-    "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-    "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)
-  );
-#endif
-}
-
-// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for
-// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
-// for inputs A and outputs C.
-__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-    "{\n"
-    "   .reg .b64 p;\n"
-    "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
-    "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
-    "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
-  );
-#endif
-}
-
-// Async copy fence.
-__device__ inline void cp_async_fence() {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  asm volatile("cp.async.commit_group;\n" ::);
-#endif
-}
-
-// Wait until at most `n` async copy stages are still pending.
-template <int n>
-__device__ inline void cp_async_wait() {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
-#endif
-}
-
-// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
-__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
-  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
-  float* c = reinterpret_cast<float*>(&frag_c);
-  asm volatile(
-    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-    "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
-    : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
-    :  "r"(a[0]),  "r"(a[1]),  "r"(a[2]),  "r"(a[3]),  "r"(b[0]),  "r"(b[1]),
-       "f"(c[0]),  "f"(c[1]),  "f"(c[2]),  "f"(c[3])
-  );
-#endif
-}
-
-// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
-__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-    "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
-    : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
-  );
-#endif
-}
-
-// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
-// automatically recognize it in all cases.
-template <int lut>
-__device__ inline int lop3(int a, int b, int c) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  int res;
-  asm volatile(
-    "lop3.b32 %0, %1, %2, %3, %4;\n"
-    : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
-  );
-  return res;
-#endif
-}
-
-// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
-// We mostly follow the strategy in the link below, with some small changes:
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-__device__ inline FragB dequant(int q) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  const int LO = 0x000f000f;
-  const int HI = 0x00f000f0;
-  const int EX = 0x64006400;
-  // Guarantee that the `(a & b) | c` operations are LOP3s.
-  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
-  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
-  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
-  const int SUB = 0x64086408;
-  const int MUL = 0x2c002c00;
-  const int ADD = 0xd480d480;
-  FragB frag_b;
-  frag_b[0] = __hsub2(
-    *reinterpret_cast<half2*>(&lo),
-    *reinterpret_cast<const half2*>(&SUB)
-  );
-  frag_b[1] = __hfma2(
-    *reinterpret_cast<half2*>(&hi),
-    *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)
-  );
-  return frag_b;
-#endif
-}
-
-// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
-__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
-  frag_b[0] = __hmul2(frag_b[0], s);
-  frag_b[1] = __hmul2(frag_b[1], s);
-#endif
-}
-
-// Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int* lock, int count) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  if (threadIdx.x == 0) {
-    int state = -1;
-    do
-      // Guarantee that subsequent writes by this threadblock will be visible globally.
-      asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
-    while (state != count);
-  }
-  __syncthreads();
-#endif
-}
-
-// Release barrier and increment visitation count.
-__device__ inline void barrier_release(int* lock, bool reset = false) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    if (reset) {
-      lock[0] = 0;
-      return;
-    }
-    int val = 1;
-    // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
-    asm volatile ("fence.acq_rel.gpu;\n");
-    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
-  }
-#endif
-}
-
-template <
-  const int threads, // number of threads in a threadblock
-  const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
-  const int thread_n_blocks, // same for n dimension (output)
-  const int thread_k_blocks, // same for k dimension (reduction)
-  const int stages, // number of stages for the async global->shared fetch pipeline
-  const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
->
-__global__ void Marlin(
-  const int4* __restrict__ A, // fp16 input matrix of shape mxk
-  const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
-        int4* __restrict__ C, // fp16 output buffer of shape mxn
-  const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
-  int  prob_m, // batch dimension m
-  int  prob_n, // output dimension n
-  int  prob_k, // reduction dimension k
-  int* locks // extra global storage for barrier synchronization
-) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
-  // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
-  //   0 1 3
-  //   0 2 3
-  //   1 2 4
-  // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
-  // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
-  // possible.
-
-  int k_tiles = prob_k / 16 / thread_k_blocks;
-  int n_tiles = prob_n / 16 / thread_n_blocks;
-  int iters = ceildiv(k_tiles * n_tiles, gridDim.x);
-  // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
-  // where a stripe starts in the middle of group.
-  if (group_blocks != -1)
-    iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));
-
-  int slice_row = (iters * blockIdx.x) % k_tiles;
-  int slice_col = (iters * blockIdx.x) / k_tiles;
-  int slice_iters; // number of threadblock tiles in the current slice
-  int slice_count = 0; // total number of active threadblocks in the current slice
-  int slice_idx; // index of threadblock in current slice; numbered bottom to top
-
-  // Compute all information about the current slice which is required for synchronization.
-  auto init_slice = [&] () {
-    slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row);
-    if (slice_iters < 0 || slice_col >= n_tiles)
-      slice_iters = 0;
-    if (slice_iters == 0)
-      return;
-    if (slice_row + slice_iters > k_tiles)
-      slice_iters = k_tiles - slice_row;
-    slice_count = 1;
-    slice_idx = 0;
-    int col_first = iters * ceildiv(k_tiles * slice_col, iters);
-    if (col_first <= k_tiles * (slice_col + 1)) {
-      int col_off = col_first - k_tiles * slice_col;
-      slice_count = ceildiv(k_tiles - col_off, iters);
-      if (col_off > 0)
-        slice_count++;
-      int delta_first = iters * blockIdx.x - col_first;
-      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
-        slice_idx = slice_count - 1;
-      else {
-        slice_idx = slice_count - 1 - delta_first / iters;
-        if (col_off > 0)
-          slice_idx--;
-      }
-    }
-  };
-  init_slice();
-
-  int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
-  // We typically use `constexpr` to indicate that this value is a compile-time constant
-  constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
-  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
-  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
-  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
-  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
-  constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
-  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
-  constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile
-
-  int b_gl_stride = 16 * prob_n / 32;
-  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
-  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
-  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
-  constexpr int b_sh_wr_delta = threads;
-  constexpr int b_sh_rd_delta = threads;
-  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
-  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
-
-  int s_gl_stride = prob_n / 8;
-  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
-  constexpr int s_sh_stage = s_sh_stride;
-  int s_gl_rd_delta = s_gl_stride;
-
-  // Global A read index of current thread.
-  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
-  a_gl_rd += a_gl_rd_delta_o * slice_row;
-  // Shared write index of current thread.
-  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
-  // Shared read index.
-  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
-  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
-
-  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
-  b_gl_rd += b_sh_stride * slice_col;
-  b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x;
-  int b_sh_rd = threadIdx.x;
-
-  int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
-  int s_sh_wr = threadIdx.x;
-  int s_sh_rd;
-  // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
-  // layout in the former and in row-major in the latter case.
-  if (group_blocks != -1)
-    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
-  else
-    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
-
-  // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
-  // required for a certain tilesize or when the batchsize is not a multiple of 16.
-  bool a_sh_wr_pred[a_sh_wr_iters];
-  #pragma unroll
-  for (int i = 0; i < a_sh_wr_iters; i++)
-    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
-  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
-
-  // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
-  // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
-  // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
-  // on NSight-Compute) that each warp must also write a consecutive memory segment?
-  auto transform_a = [&] (int i) {
-    int row = i / a_gl_rd_delta_o;
-    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
-  };
-  // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
-  // accesses are static, we simply precompute both transformed reads and writes.
-  int a_sh_wr_trans[a_sh_wr_iters];
-  #pragma unroll
-  for (int i = 0; i < a_sh_wr_iters; i++)
-    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
-  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
-  #pragma unroll
-  for (int i = 0; i < b_sh_wr_iters; i++) {
-    #pragma unroll
-    for (int j = 0; j < thread_m_blocks; j++)
-      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
-  }
-
-  // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between
-  // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization.
-  const int4* B_ptr[b_sh_wr_iters];
-  #pragma unroll
-  for (int i = 0; i < b_sh_wr_iters; i++)
-    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
-
-  extern __shared__ int4 sh[];
-  // Shared memory storage for global fetch pipelines.
-  int4* sh_a = sh;
-  int4* sh_b = sh_a + (stages * a_sh_stage);
-  int4* sh_s = sh_b + (stages * b_sh_stage);
-  // Register storage for double buffer of shared memory reads.
-  FragA frag_a[2][thread_m_blocks];
-  I4 frag_b_quant[2];
-  FragC frag_c[thread_m_blocks][4][2];
-  FragS frag_s[2][4];
-
-  // Zero accumulators.
-  auto zero_accums = [&] () {
-    #pragma unroll
-    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
-      reinterpret_cast<float*>(frag_c)[i] = 0;
-  };
-
-  // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
-  auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
-    if (pred) {
-      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
-      #pragma unroll
-      for (int i = 0; i < a_sh_wr_iters; i++) {
-        cp_async4_pred(
-          &sh_a_stage[a_sh_wr_trans[i]],
-          &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
-          a_sh_wr_pred[i]
-        );
-      }
-      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
-      #pragma unroll
-      for (int i = 0; i < b_sh_wr_iters; i++) {
-        cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
-        B_ptr[i] += b_gl_rd_delta_o;
-      }
-      // Only fetch scales if this tile starts a new group
-      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
-        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
-        if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
-        s_gl_rd += s_gl_rd_delta;
-      }
-    }
-    // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
-    cp_async_fence();
-  };
-
-  // Wait until the next thread tile has been loaded to shared memory.
-  auto wait_for_stage = [&] () {
-    // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
-    // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
-    cp_async_wait<stages - 2>();
-    __syncthreads();
-  };
-
-  // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
-  auto fetch_to_registers = [&] (int k, int pipe) {
-    // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
-    // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the
-    // compiler and correspondingly a noticable drop in performance.
-    if (group_blocks != -1) {
-      int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
-      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
-    }
-    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
-    #pragma unroll
-    for (int i = 0; i < thread_m_blocks; i++)
-      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
-    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
-    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
-  };
-
-  // Execute the actual tensor core matmul of a sub-tile.
-  auto matmul = [&] (int k) {
-    // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
-    #pragma unroll
-    for (int j = 0; j < 4; j++) {
-      int b_quant = frag_b_quant[k % 2][j];
-      int b_quant_shift = b_quant >> 8;
-      FragB frag_b0 = dequant(b_quant);
-      // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
-      if (group_blocks != -1)
-        scale(frag_b0, frag_s[k % 2][j], 0);
-      FragB frag_b1 = dequant(b_quant_shift);
-      if (group_blocks != -1)
-        scale(frag_b1, frag_s[k % 2][j], 1);
-      #pragma unroll
-      for (int i = 0; i < thread_m_blocks; i++) {
-        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
-        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
-      }
-    }
-  };
-
-  // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
-  // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
-  // location; which we have to reduce over in the end. We do in shared memory.
-  auto thread_block_reduce = [&] () {
-    constexpr int red_off = threads / b_sh_stride / 2;
-    if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride;
-      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
-      constexpr int red_sh_delta = b_sh_stride;
-      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
-
-      // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
-      // e.g., for two warps we write only once by warp 1 and read only once by warp 0.
-
-      #pragma unroll
-      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
-        #pragma unroll
-        for (int i = red_off; i > 0; i /= 2) {
-          if (i <= red_idx && red_idx < 2 * i) {
-            #pragma unroll
-            for (int j = 0; j < 4 * 2; j++) {
-              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
-              if (i < red_off) {
-                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
-                #pragma unroll
-                for (int k = 0; k < 4; k++)
-                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
-              }
-              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
-            }
-          }
-          __syncthreads();
-        }
-        if (red_idx == 0) {
-          #pragma unroll
-          for (int i = 0; i < 4 * 2; i++) {
-            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
-            #pragma unroll
-            for (int j = 0; j < 4; j++)
-              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
-          }
-        }
-        __syncthreads();
-      }
-    }
-  };
-
-  // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
-  // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather
-  // small, we perform this reduction serially in L2 cache.
-  auto global_reduce = [&] (bool first = false, bool last = false) {
-    // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
-    // To do this, we write out results in FP16 (but still reduce with FP32 compute).
-    constexpr int active_threads = 32 * thread_n_blocks / 4;
-    if (threadIdx.x < active_threads) {
-      int c_gl_stride = prob_n / 8;
-      int c_gl_wr_delta_o = 8 * c_gl_stride;
-      int c_gl_wr_delta_i = 4 * (active_threads / 32);
-      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
-      c_gl_wr += (2 * thread_n_blocks) * slice_col;
-      constexpr int c_sh_wr_delta = active_threads;
-      int c_sh_wr = threadIdx.x;
-
-      int row = (threadIdx.x % 32) / 4;
-
-      if (!first) {
-        // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
-        // hence we also use async-copies even though these fetches are not actually asynchronous.
-        #pragma unroll
-        for (int i = 0; i < thread_m_blocks * 4; i++) {
-          cp_async4_pred(
-            &sh[c_sh_wr + c_sh_wr_delta * i],
-            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
-            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
-          );
-        }
-        cp_async_fence();
-        cp_async_wait<0>();
-      }
-
-      #pragma unroll
-      for (int i = 0; i < thread_m_blocks * 4; i++) {
-        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
-          if (!first) {
-            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
-            #pragma unroll
-            for (int j = 0; j < 2 * 4; j++) {
-              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
-                reinterpret_cast<__half*>(&c_red)[j]
-              );
-            }
-          }
-          if (!last) {
-            int4 c;
-            #pragma unroll
-            for (int j = 0; j < 2 * 4; j++) {
-              reinterpret_cast<__half*>(&c)[j] = __float2half(
-                reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
-              );
-            }
-            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
-          }
-        }
-      }
-    }
-  };
-
-  // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step,
-  // the reduction above is performed in fragment layout.
-  auto write_result = [&] () {
-    int c_gl_stride = prob_n / 8;
-    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
-    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
-    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
-
-    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
-    c_gl_wr += (2 * thread_n_blocks) * slice_col;
-    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
-    c_sh_wr += 32 * (threadIdx.x / 32);
-    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
-
-    int c_gl_wr_end = c_gl_stride * prob_m;
-
-    // We first reorder in shared memory to guarantee the most efficient final global write patterns
-    auto write = [&] (int idx, float c0, float c1, FragS& s) {
-      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
-      if (group_blocks == -1) // for per-column quantization we finally apply the scale here
-        res = __hmul2(res, s[0]);
-      ((half2*) sh)[idx] = res;
-    };
-    if (threadIdx.x / 32 < thread_n_blocks / 4) {
-      #pragma unroll
-      for (int i = 0; i < thread_m_blocks; i++) {
-        #pragma unroll
-        for (int j = 0; j < 4; j++) {
-          int wr = c_sh_wr + 8 * j;
-          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
-          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
-          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
-          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
-        }
-        c_sh_wr += 16 * (4 * c_sh_stride);
-      }
-    }
-    __syncthreads();
-
-    #pragma unroll
-    for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
-      if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh[c_sh_rd];
-        c_gl_wr += c_gl_wr_delta;
-        c_sh_rd += c_sh_rd_delta;
-      }
-    }
-  };
-
-  // Start global fetch and register load pipelines.
-  auto start_pipes = [&] () {
-    #pragma unroll
-    for (int i = 0; i < stages - 1; i++)
-      fetch_to_shared(i, i, i < slice_iters);
-    zero_accums();
-    wait_for_stage();
-    fetch_to_registers(0, 0);
-    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
-  };
-  start_pipes();
-
-  // Main loop.
-  while (slice_iters) {
-    // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
-    // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
-    #pragma unroll
-    for (int pipe = 0; pipe < stages;) {
-      #pragma unroll
-      for (int k = 0; k < b_sh_wr_iters; k++) {
-        fetch_to_registers(k + 1, pipe % stages);
-        if (k == b_sh_wr_iters - 2) {
-          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
-          pipe++;
-          wait_for_stage();
-        }
-        matmul(k);
-      }
-      slice_iters--;
-      if (slice_iters == 0)
-        break;
-    }
-    a_gl_rd += a_gl_rd_delta_o * stages;
-
-    // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
-    // readable, other ways of writing the loop seemed to noticeably worse performance after compliation.
-    if (slice_iters == 0) {
-      cp_async_wait<0>();
-      bool last = slice_idx == slice_count - 1;
-      // For per-column scales, we only fetch them here in the final step before write-out
-      if (group_blocks == -1 && last) {
-        if (s_sh_wr_pred)
-          cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
-        cp_async_fence();
-      }
-      thread_block_reduce();
-      if (group_blocks == -1 && last) {
-        cp_async_wait<0>();
-        __syncthreads();
-        if (threadIdx.x / 32 < thread_n_blocks / 4) {
-          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
-          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
-        }
-      }
-      if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
-        barrier_acquire(&locks[slice_col], slice_idx);
-        global_reduce(slice_idx == 0, last);
-        barrier_release(&locks[slice_col], last);
-      }
-      if (last) // only the last block in a slice actually writes the result
-        write_result();
-      slice_row = 0;
-      slice_col++;
-      init_slice();
-      if (slice_iters) {
-        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
-        #pragma unroll
-        for (int i = 0; i < b_sh_wr_iters; i++)
-          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
-        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
-        start_pipes();
-      }
-    }
-  }
-#endif
-}
-
-
-// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more
-// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
-const int THREADS = 256;
-const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
-
-#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
-  else if ( \
-    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
-    group_blocks == GROUP_BLOCKS \
-  ) { \
-    cudaFuncSetAttribute( \
-      Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
-      cudaFuncAttributeMaxDynamicSharedMemorySize, \
-      SHARED_MEM \
-    ); \
-    Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS><<<blocks, THREADS, SHARED_MEM, stream>>>( \
-      A_ptr, B_ptr, C_ptr, s_ptr, \
-      prob_m, prob_n, prob_k, \
-      locks \
-    ); \
-  }
-
-const int ERR_PROB_SHAPE = 1;
-const int ERR_KERN_SHAPE = 2;
-
-int marlin_cuda(
-  const void* A,
-  const void* B,
-        void* C,
-        void* s,
-  int prob_m,
-  int prob_n,
-  int prob_k,
-  void* workspace,
-  int groupsize = -1,
-  int dev = 0,
-  cudaStream_t stream = 0,
-  int thread_k = -1,
-  int thread_n = -1,
-  int sms = -1
-) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  int tot_m = prob_m;
-  int tot_m_blocks = ceildiv(tot_m, 16);
-
-  if (sms == -1)
-    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
-  if (thread_k == -1 || thread_n == -1) {
-    if (prob_m <= 16) {
-      // For small batchizes, better partioning is slightly more important than better compute utilization
-      thread_k = 128;
-      thread_n = 128;
-    } else {
-      thread_k = 64;
-      thread_n = 256;
-    }
-  }
-
-  int thread_k_blocks = thread_k / 16;
-  int thread_n_blocks = thread_n / 16;
-  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
-  int blocks = sms;
-
-  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
-    return ERR_PROB_SHAPE;
-  if (prob_m == 0 || prob_n == 0 || prob_k == 0)
-    return 0;
-
-  const int4* A_ptr = (const int4*) A;
-  const int4* B_ptr = (const int4*) B;
-  int4* C_ptr = (int4*) C;
-  const int4* s_ptr = (const int4*) s;
-
-  int cols = prob_n / thread_n;
-  int* locks = (int*) workspace;
-
-  int ret = 0;
-  for (int i = 0; i < tot_m_blocks; i += 4) {
-    int thread_m_blocks = tot_m_blocks - i;
-    prob_m = tot_m - 16 * i;
-    if (thread_m_blocks > 4) {
-      thread_m_blocks = 4;
-      prob_m = 64;
-    }
-
-    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
-    // in our testing, however many more are, in principle, possible.
-    if (false) {}
-    CALL_IF(1,  8,  8, -1)
-    CALL_IF(1,  8,  8,  8)
-    CALL_IF(1, 16,  4, -1)
-    CALL_IF(1, 16,  4,  8)
-    CALL_IF(2, 16,  4, -1)
-    CALL_IF(2, 16,  4,  8)
-    CALL_IF(3, 16,  4, -1)
-    CALL_IF(3, 16,  4,  8)
-    CALL_IF(4, 16,  4, -1)
-    CALL_IF(4, 16,  4,  8)
-    else
-      ret = ERR_KERN_SHAPE;
-
-    A_ptr += 16 * thread_m_blocks * (prob_k / 8);
-    C_ptr += 16 * thread_m_blocks * (prob_n / 8);
-  }
-
-  return ret;
-#endif
-}
-
-#endif
-
-} // namespace marlin
-} // namespace aphrodite
-
-const int ERR_PROB_SHAPE = 1;
-const int ERR_KERN_SHAPE = 2;
-
-// input:     `torch.half` input matrix of shape `(m, k)` in standard row-major layout
-// weights:   `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
-// output:    `torch.half` out matrix of shape `(m, n)` in standard row-major layout
-// scales:    `torch.half` scales of shape `(m / groupsize, n)`
-// workspace: `torch.int` tensor with at least `n / 128` entries that are all zero
-
-void marlin_gemm(
-  const torch::Tensor& input,
-  const torch::Tensor& weights,
-        torch::Tensor& output,
-  const torch::Tensor& scales,
-        torch::Tensor& workspace
-) {
-#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1)
-  int thread_k = -1;
-  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1)
-  int thread_n = -1;
-  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
-  int sms = -1;
-
-  int prob_m = input.size(0);
-  int prob_n = output.size(1);
-  int prob_k = input.size(1);
-  int groupsize = (scales.size(0) == 1) ? -1 : prob_k / scales.size(0);
-  if (groupsize != -1 && groupsize * scales.size(0) != prob_k)
-    AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups.");
-  int dev = input.get_device();
-  int err = aphrodite::marlin::marlin_cuda(
-    input.data_ptr(),
-    weights.data_ptr(),
-    output.data_ptr(),
-    scales.data_ptr(),
-    prob_m, prob_n, prob_k,
-    workspace.data_ptr(),
-    groupsize,
-    dev,
-    at::cuda::getCurrentCUDAStream(dev),
-    thread_k,
-    thread_n,
-    sms
-  );
-  if (err == ERR_PROB_SHAPE) {
-    AT_ERROR(
-      "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
-      " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
-    );
-  } else if (err == ERR_KERN_SHAPE) {
-    AT_ERROR(
-      "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
-    );
-  }
-#endif
-}
+ #include <torch/extension.h>
+
+ #include <ATen/cuda/CUDAContext.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <cuda.h>
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+ 
+ #include <iostream>
+ 
+ template <typename T> inline std::string str(T x) { return std::to_string(x); }
+ 
+ namespace marlin {
+ 
+ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+ 
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ 
+ // Instances of `Vec` are used to organize groups of >>registers<<, as needed
+ // for instance as inputs to tensor core operations. Consequently, all
+ // corresponding index accesses must be compile-time constants, which is why we
+ // extensively use `#pragma unroll` throughout the kernel code to guarantee
+ // this.
+ template <typename T, int n> struct Vec {
+   T elems[n];
+   __device__ T &operator[](int i) { return elems[i]; }
+ };
+ 
+ using I4 = Vec<int, 4>;
+ 
+ // Matrix fragments for tensor core instructions; their precise layout is
+ // documented here:
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ using FragA = Vec<half2, 4>;
+ using FragB = Vec<half2, 2>;
+ using FragC = Vec<float, 4>;
+ using FragS = Vec<half2, 1>; // quantization scales
+ 
+ // Predicated asynchronous global->shared copy; used for inputs A where we apply
+ // predication to handle batchsizes that are not multiples of 16.
+ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
+                                       bool pred = true) {
+   const int BYTES = 16;
+   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+   asm volatile("{\n"
+                "   .reg .pred p;\n"
+                "   setp.ne.b32 p, %0, 0;\n"
+                "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+                "}\n" ::"r"((int)pred),
+                "r"(smem), "l"(glob_ptr), "n"(BYTES));
+ }
+ 
+ // Asynchronous global->shared copy with a cache hint indicating that the values
+ // may be evicted immediately; used for quantized weights B, which are only
+ // accessed precisely once and should thus not pollute the L2 cache which we
+ // need for inputs A and outputs C.
+ __device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) {
+   const int BYTES = 16;
+   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+   asm volatile(
+       "{\n"
+       "   .reg .b64 p;\n"
+       "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
+       "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
+       "}\n" ::"r"(smem),
+       "l"(glob_ptr), "n"(BYTES));
+ }
+ 
+ // Async copy fence.
+ __device__ inline void cp_async_fence() {
+   asm volatile("cp.async.commit_group;\n" ::);
+ }
+ 
+ // Wait until at most `n` async copy stages are still pending.
+ template <int n> __device__ inline void cp_async_wait() {
+   asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+ }
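As an illustrative aside (outside this diff): the four helpers above are meant to
be used in an issue -> commit -> wait -> sync pattern. A minimal sketch, assuming
only the helpers defined here (the kernel name and indexing are hypothetical, and
the kernel must be launched with enough dynamic shared memory):

    __global__ void prefetch_one_tile(const int4 *gmem, int4 *out) {
      extern __shared__ int4 smem[];
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      cp_async4_stream(&smem[threadIdx.x], &gmem[idx]); // issue one 16B copy
      cp_async_fence();                                 // commit the copy group
      cp_async_wait<0>();                               // wait for all groups
      __syncthreads();                                  // make smem visible
      out[idx] = smem[threadIdx.x];
    }

The Marlin kernel below keeps `stages` such groups in flight and only waits with
cp_async_wait<stages - 2>(), so compute overlaps the outstanding fetches.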
+ 
+ // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
+ // output/accumulation.
+ __device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
+                            FragC &frag_c) {
+   const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
+   const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
+   float *c = reinterpret_cast<float *>(&frag_c);
+   asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+                "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+                : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+                  "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+ }
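As a consistency note (outside this diff): the fragment typedefs above match the
per-thread register footprint of the m16n8k16 f16 MMA documented at the link
above, which could be expressed as the following (hypothetical) checks:

    static_assert(sizeof(FragA) == 8 * sizeof(half), "A: 8 halves per thread");
    static_assert(sizeof(FragB) == 4 * sizeof(half), "B: 4 halves per thread");
    static_assert(sizeof(FragC) == 4 * sizeof(float), "C: 4 fp32 accumulators");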
+ 
+ // Instruction for loading a full 16x16 matrix fragment of operand A from shared
+ // memory, directly in tensor core layout.
+ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
+   uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
+   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+                : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+                : "r"(smem));
+ }
+ 
+ // Lookup-table based 3-input logical operation; explicitly used for
+ // dequantization as the compiler does not seem to automatically recognize it in
+ // all cases.
+ template <int lut> __device__ inline int lop3(int a, int b, int c) {
+   int res;
+   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                : "=r"(res)
+                : "r"(a), "r"(b), "r"(c), "n"(lut));
+   return res;
+ }
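As an illustrative aside (outside this diff): the `lut` immediate is obtained by
evaluating the desired boolean function on the canonical operand masks a = 0xF0,
b = 0xCC, c = 0xAA, which is exactly what `dequant` below does for (a & b) | c:

    static_assert(((0xf0 & 0xcc) | 0xaa) == 0xea, "(a & b) | c -> LUT 0xEA");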
+ 
+ // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+ // values. We mostly follow the strategy in the link below, with some small
+ // changes:
+ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+ __device__ inline FragB dequant(int q) {
+   const int LO = 0x000f000f;
+   const int HI = 0x00f000f0;
+   const int EX = 0x64006400;
+   // Guarantee that the `(a & b) | c` operations are LOP3s.
+   int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+   int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+   // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+   // directly into `SUB` and `ADD`.
+   const int SUB = 0x64086408;
+   const int MUL = 0x2c002c00;
+   const int ADD = 0xd480d480;
+   FragB frag_b;
+   frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
+                       *reinterpret_cast<const half2 *>(&SUB));
+   frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
+                       *reinterpret_cast<const half2 *>(&MUL),
+                       *reinterpret_cast<const half2 *>(&ADD));
+   return frag_b;
+ }
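As an illustrative aside (outside this diff): ignoring the interleaved nibble
ordering produced by Marlin's weight repacking, each packed 32-bit word holds
eight 4-bit weights with an implicit zero point of 8, so a hypothetical scalar
reference for what the fused fp16 trick computes is simply:

    inline void dequant_reference(unsigned int q, float out[8]) {
      for (int i = 0; i < 8; ++i)
        out[i] = float((q >> (4 * i)) & 0xF) - 8.0f; // remove symmetric zero point
    }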
+ 
+ // Multiply dequantized values by the corresponding quantization scale; used
+ // only for grouped quantization.
+ __device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
+   half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+   frag_b[0] = __hmul2(frag_b[0], s);
+   frag_b[1] = __hmul2(frag_b[1], s);
+ }
+ 
+ // Wait until barrier reaches `count`, then lock for current threadblock.
+ __device__ inline void barrier_acquire(int *lock, int count) {
+   if (threadIdx.x == 0) {
+     int state = -1;
+     do
+       // Guarantee that subsequent writes by this threadblock will be visible
+       // globally.
+       asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                    : "=r"(state)
+                    : "l"(lock));
+     while (state != count);
+   }
+   __syncthreads();
+ }
+ 
+ // Release barrier and increment visitation count.
+ __device__ inline void barrier_release(int *lock, bool reset = false) {
+   __syncthreads();
+   if (threadIdx.x == 0) {
+     if (reset) {
+       lock[0] = 0;
+       return;
+     }
+     int val = 1;
+     // Make sure that all writes since acquiring this barrier are visible
+     // globally, while releasing the barrier.
+     asm volatile("fence.acq_rel.gpu;\n");
+     asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                  :
+                  : "l"(lock), "r"(val));
+   }
+ }
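Illustrative usage (this mirrors the main loop further below rather than adding
new code): threadblocks that share a column slice serialize their partial-sum
accumulation as

    barrier_acquire(&locks[slice_col], slice_idx); // wait for slice_idx releases
    global_reduce(slice_idx == 0, last);           // accumulate into C via L2
    barrier_release(&locks[slice_col], last);      // last block resets the counter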
+ 
+ template <const int threads,         // number of threads in a threadblock
+           const int thread_m_blocks, // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the threadblock
+           const int thread_n_blocks, // same for n dimension (output)
+           const int thread_k_blocks, // same for k dimension (reduction)
+           const int stages, // number of stages for the async global->shared
+                             // fetch pipeline
+           const int group_blocks = -1 // number of consecutive 16x16 blocks with
+                                       // a separate quantization scale
+           >
+ __global__ void
+ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+        const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
+        int4 *__restrict__ C,       // fp16 output buffer of shape mxn
+        const int4
+            *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+        int prob_m,          // batch dimension m
+        int prob_n,          // output dimension n
+        int prob_k,          // reduction dimension k
+        int *locks           // extra global storage for barrier synchronization
+ ) {
+   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
+   // same size, which might involve multiple column "slices" (of width 16 *
+   // `thread_n_blocks`). Stripes are defined as in the following example for a
+   // 3x3 tile grid and 5 SMs:
+   //   0 1 3
+   //   0 2 3
+   //   1 2 4
+   // While this kind of partitioning makes things somewhat more complicated, it
+   // ensures good utilization of all SMs for many kinds of shape and GPU
+   // configurations, while requiring as few slow global cross-threadblock
+   // reductions as possible.
+ 
+   // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
+   // better partitioning with fewer reductions
+   int parallel = 1;
+   if (prob_m > 16 * thread_m_blocks) {
+     parallel = prob_m / (16 * thread_m_blocks);
+     prob_m = 16 * thread_m_blocks;
+   }
+ 
+   int k_tiles = prob_k / 16 / thread_k_blocks;
+   int n_tiles = prob_n / 16 / thread_n_blocks;
+   int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
+   // Ensure that the number of tiles in each stripe is a multiple of the
+   // groupsize; this avoids an annoying special case where a stripe starts in
+   // the middle of a group.
+   if (group_blocks != -1)
+     iters = (group_blocks / thread_k_blocks) *
+             ceildiv(iters, (group_blocks / thread_k_blocks));
+ 
+   int slice_row = (iters * blockIdx.x) % k_tiles;
+   int slice_col_par = (iters * blockIdx.x) / k_tiles;
+   int slice_col = slice_col_par;
+   int slice_iters; // number of threadblock tiles in the current slice
+   int slice_count =
+       0;         // total number of active threadblocks in the current slice
+   int slice_idx; // index of threadblock in current slice; numbered bottom to
+                  // top
+ 
+   // We can easily implement parallel problem execution by just remapping
+   // indices and advancing global pointers
+   if (slice_col_par >= n_tiles) {
+     A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
+     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
+     locks += (slice_col_par / n_tiles) * n_tiles;
+     slice_col = slice_col_par % n_tiles;
+   }
+ 
+   // Compute all information about the current slice which is required for
+   // synchronization.
+   auto init_slice = [&]() {
+     slice_iters =
+         iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
+     if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
+       slice_iters = 0;
+     if (slice_iters == 0)
+       return;
+     if (slice_row + slice_iters > k_tiles)
+       slice_iters = k_tiles - slice_row;
+     slice_count = 1;
+     slice_idx = 0;
+     int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
+     if (col_first <= k_tiles * (slice_col_par + 1)) {
+       int col_off = col_first - k_tiles * slice_col_par;
+       slice_count = ceildiv(k_tiles - col_off, iters);
+       if (col_off > 0)
+         slice_count++;
+       int delta_first = iters * blockIdx.x - col_first;
+       if (delta_first < 0 || (col_off == 0 && delta_first == 0))
+         slice_idx = slice_count - 1;
+       else {
+         slice_idx = slice_count - 1 - delta_first / iters;
+         if (col_off > 0)
+           slice_idx--;
+       }
+     }
+     if (slice_col == n_tiles) {
+       A += 16 * thread_m_blocks * prob_k / 8;
+       C += 16 * thread_m_blocks * prob_n / 8;
+       locks += n_tiles;
+       slice_col = 0;
+     }
+   };
+   init_slice();
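As an illustrative aside (outside this diff): the stripe assignment computed
above can be replayed on the host for the 3x3-tile / 5-SM example from the
comment at the top of the kernel; the names below are hypothetical and the
output reproduces that matrix column by column:

    #include <cstdio>
    constexpr int cdiv(int a, int b) { return (a + b - 1) / b; }
    int main() {
      const int k_tiles = 3, n_tiles = 3, blocks = 5;
      const int iters = cdiv(k_tiles * n_tiles, blocks); // 2 tiles per block
      for (int b = 0; b < blocks; ++b) {
        int first = iters * b;
        int last = iters * (b + 1);
        if (last > k_tiles * n_tiles) last = k_tiles * n_tiles;
        for (int t = first; t < last; ++t) // walk the k dimension first, then n
          std::printf("SM %d -> tile (row %d, col %d)\n", b, t % k_tiles,
                      t / k_tiles);
      }
    }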
+ 
+   int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
+   // We typically use `constexpr` to indicate that this value is a compile-time
+   // constant
+   constexpr int a_sh_stride =
+       16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
+   constexpr int a_gl_rd_delta_o =
+       16 * thread_k_blocks /
+       8; // delta between subsequent A tiles in global memory
+   int a_gl_rd_delta_i =
+       a_gl_stride *
+       (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
+   constexpr int a_sh_wr_delta =
+       a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
+   constexpr int a_sh_rd_delta_o =
+       2 * ((threads / 32) /
+            (thread_n_blocks / 4)); // between shared memory tile reads
+   constexpr int a_sh_rd_delta_i =
+       a_sh_stride * 16; // within a shared memory tile
+   constexpr int a_sh_stage =
+       a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
+   constexpr int a_sh_wr_iters =
+       ceildiv(a_sh_stage,
+               a_sh_wr_delta); // number of shared write iterations for a tile
+ 
+   int b_gl_stride = 16 * prob_n / 32;
+   constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
+   int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
+   int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
+   constexpr int b_sh_wr_delta = threads;
+   constexpr int b_sh_rd_delta = threads;
+   constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
+   constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
+ 
+   int s_gl_stride = prob_n / 8;
+   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
+   constexpr int s_sh_stage = s_sh_stride;
+   int s_gl_rd_delta = s_gl_stride;
+ 
+   // Global A read index of current thread.
+   int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                 (threadIdx.x % a_gl_rd_delta_o);
+   a_gl_rd += a_gl_rd_delta_o * slice_row;
+   // Shared write index of current thread.
+   int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                 (threadIdx.x % a_gl_rd_delta_o);
+   // Shared read index.
+   int a_sh_rd =
+       a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
+   a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
+ 
+   int b_gl_rd =
+       b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+   b_gl_rd += b_sh_stride * slice_col;
+   b_gl_rd += b_gl_rd_delta_o * slice_row;
+   int b_sh_wr = threadIdx.x;
+   int b_sh_rd = threadIdx.x;
+ 
+   int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+                 s_sh_stride * slice_col + threadIdx.x;
+   int s_sh_wr = threadIdx.x;
+   int s_sh_rd;
+   // We use a different scale layout for grouped and column-wise quantization as
+   // we scale a `half2` tile in column-major layout in the former and in
+   // row-major in the latter case.
+   if (group_blocks != -1)
+     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+               (threadIdx.x % 32) / 4;
+   else
+     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+               (threadIdx.x % 32) % 4;
+ 
+   // Precompute which thread should not read memory in which iterations; this is
+   // needed if there are more threads than required for a certain tilesize or
+   // when the batchsize is not a multiple of 16.
+   bool a_sh_wr_pred[a_sh_wr_iters];
+ #pragma unroll
+   for (int i = 0; i < a_sh_wr_iters; i++)
+     a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
+   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
+ 
+   // To ensure that writing A tiles to shared memory and reading them back (the
+   // latter in fragment format) is fully bank-conflict-free, we need to use a
+   // rather fancy XOR-based layout. The key here is that neither reads nor
+   // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
+   // same shared memory banks. Further, it seems (based on NSight-Compute) that
+   // each warp must also write a consecutive memory segment?
+   auto transform_a = [&](int i) {
+     int row = i / a_gl_rd_delta_o;
+     return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
+   };
+   // Since the computation of this remapping is non-trivial and, due to our main
+   // loop unrolls, all shared memory accesses are static, we simply precompute
+   // both transformed reads and writes.
+   int a_sh_wr_trans[a_sh_wr_iters];
+ #pragma unroll
+   for (int i = 0; i < a_sh_wr_iters; i++)
+     a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
+   int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
+ #pragma unroll
+   for (int i = 0; i < b_sh_wr_iters; i++) {
+ #pragma unroll
+     for (int j = 0; j < thread_m_blocks; j++)
+       a_sh_rd_trans[i][j] =
+           transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
+   }
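As an illustrative aside (outside this diff): the general idea behind such an
XOR swizzle, shown for a hypothetical 8x8 tile of 16-byte words, is that storing
element (row, col) at column col ^ row lets both a whole row and a whole column
of accesses touch 8 distinct physical columns, avoiding bank conflicts:

    // store: smem[row * 8 + swizzle(row, col)] = val;
    // load : val = smem[row * 8 + swizzle(row, col)];
    constexpr int swizzle(int row, int col) { return (col ^ row) & 7; }

The actual mapping used here is the `transform_a` lambda above, which applies
this kind of XOR to the tile's linearized `int4` indices.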
+ 
+   // Since B-accesses have non-constant stride they have to be computed at
+   // runtime; we break dependencies between subsequent accesses within a tile
+   // by maintaining multiple pointers (we have enough registers), a tiny
+   // optimization.
+   const int4 *B_ptr[b_sh_wr_iters];
+ #pragma unroll
+   for (int i = 0; i < b_sh_wr_iters; i++)
+     B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
+ 
+   extern __shared__ int4 sh[];
+   // Shared memory storage for global fetch pipelines.
+   int4 *sh_a = sh;
+   int4 *sh_b = sh_a + (stages * a_sh_stage);
+   int4 *sh_s = sh_b + (stages * b_sh_stage);
+   // Register storage for double buffer of shared memory reads.
+   FragA frag_a[2][thread_m_blocks];
+   I4 frag_b_quant[2];
+   FragC frag_c[thread_m_blocks][4][2];
+   FragS frag_s[2][4];
+ 
+   // Zero accumulators.
+   auto zero_accums = [&]() {
+ #pragma unroll
+     for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
+       reinterpret_cast<float *>(frag_c)[i] = 0;
+   };
+ 
+   // Asynchronously fetch the next A, B and s tile from global to the next
+   // shared memory pipeline location.
+   auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
+     if (pred) {
+       int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
+ #pragma unroll
+       for (int i = 0; i < a_sh_wr_iters; i++) {
+         cp_async4_pred(
+             &sh_a_stage[a_sh_wr_trans[i]],
+             &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
+             a_sh_wr_pred[i]);
+       }
+       int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
+ #pragma unroll
+       for (int i = 0; i < b_sh_wr_iters; i++) {
+         cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+         B_ptr[i] += b_gl_rd_delta_o;
+       }
+       // Only fetch scales if this tile starts a new group
+       if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
+         int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
+         if (s_sh_wr_pred)
+           cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+         s_gl_rd += s_gl_rd_delta;
+       }
+     }
+     // Insert a fence even when we are winding down the pipeline to ensure that
+     // waiting is also correct at this point.
+     cp_async_fence();
+   };
+ 
+   // Wait until the next thread tile has been loaded to shared memory.
+   auto wait_for_stage = [&]() {
+     // We only have `stages - 2` active fetches since we are double buffering
+     // and can only issue the next fetch when it is guaranteed that the previous
+     // shared memory load is fully complete (as it may otherwise be
+     // overwritten).
+     cp_async_wait<stages - 2>();
+     __syncthreads();
+   };
+ 
+   // Load the next sub-tile from the current location in the shared memory pipe
+   // into the current register buffer.
+   auto fetch_to_registers = [&](int k, int pipe) {
+     // It may seem inefficient that we reload the groups for every sub-tile;
+     // however, this does not seem to be a significant bottleneck, while some
+     // theoretically better attempts have led to bad instruction ordering by
+     // the compiler and correspondingly a noticeable drop in performance.
+     if (group_blocks != -1) {
+       int4 *sh_s_stage =
+           sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
+                                (pipe / (group_blocks / thread_k_blocks)));
+       reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
+     }
+     int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
+ #pragma unroll
+     for (int i = 0; i < thread_m_blocks; i++)
+       ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+     int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
+     frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
+         &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+   };
+ 
+   // Execute the actual tensor core matmul of a sub-tile.
+   auto matmul = [&](int k) {
+ // We have the m dimension as the inner loop in order to encourage overlapping
+ // dequantization and matmul operations.
+ #pragma unroll
+     for (int j = 0; j < 4; j++) {
+       int b_quant = frag_b_quant[k % 2][j];
+       int b_quant_shift = b_quant >> 8;
+       FragB frag_b0 = dequant(b_quant);
+       // If there are no groups, we can just scale the final output once and can
+       // avoid doing so for each weight.
+       if (group_blocks != -1)
+         scale(frag_b0, frag_s[k % 2][j], 0);
+       FragB frag_b1 = dequant(b_quant_shift);
+       if (group_blocks != -1)
+         scale(frag_b1, frag_s[k % 2][j], 1);
+ #pragma unroll
+       for (int i = 0; i < thread_m_blocks; i++) {
+         mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+         mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+       }
+     }
+   };
+ 
+   // Since we slice across the k dimension of a tile in order to increase the
+   // number of warps while keeping the n dimension of a tile reasonable, we have
+   // multiple warps that accumulate their partial sums for the same output
+   // location, which we have to reduce over in the end. We do this in shared
+   // memory.
+   auto thread_block_reduce = [&]() {
+     constexpr int red_off = threads / b_sh_stride / 2;
+     if (red_off >= 1) {
+       int red_idx = threadIdx.x / b_sh_stride;
+       constexpr int red_sh_stride = b_sh_stride * 4 * 2;
+       constexpr int red_sh_delta = b_sh_stride;
+       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
+                       (threadIdx.x % b_sh_stride);
+ 
+       // Parallel logarithmic shared memory reduction. We make sure to avoid any
+       // unnecessary read or write iterations, e.g., for two warps warp 1 writes
+       // only once and warp 0 reads only once.
+ 
+ #pragma unroll
+       for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+ #pragma unroll
+         for (int i = red_off; i > 0; i /= 2) {
+           if (i <= red_idx && red_idx < 2 * i) {
+ #pragma unroll
+             for (int j = 0; j < 4 * 2; j++) {
+               int red_sh_wr =
+                   red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+               if (i < red_off) {
+                 float *c_rd = reinterpret_cast<float *>(
+                     &sh[red_sh_delta * j + red_sh_rd]);
+                 float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
+ #pragma unroll
+                 for (int k = 0; k < 4; k++)
+                   reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
+                       c_rd[k] + c_wr[k];
+               }
+               sh[red_sh_wr] =
+                   reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
+             }
+           }
+           __syncthreads();
+         }
+         if (red_idx == 0) {
+ #pragma unroll
+           for (int i = 0; i < 4 * 2; i++) {
+             float *c_rd =
+                 reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
+ #pragma unroll
+             for (int j = 0; j < 4; j++)
+               reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
+                   c_rd[j];
+           }
+         }
+         __syncthreads();
+       }
+     }
+   };
+ 
+   // Since multiple threadblocks may process parts of the same column slice, we
+   // finally have to globally reduce over the results. As the striped partitioning
+   // minimizes the number of such reductions and our outputs are usually rather
+   // small, we perform this reduction serially in L2 cache.
+   auto global_reduce = [&](bool first = false, bool last = false) {
+     // We are very careful here to reduce directly in the output buffer to
+     // maximize L2 cache utilization in this step. To do this, we write out
+     // results in FP16 (but still reduce with FP32 compute).
+     constexpr int active_threads = 32 * thread_n_blocks / 4;
+     if (threadIdx.x < active_threads) {
+       int c_gl_stride = prob_n / 8;
+       int c_gl_wr_delta_o = 8 * c_gl_stride;
+       int c_gl_wr_delta_i = 4 * (active_threads / 32);
+       int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
+                     4 * (threadIdx.x / 32) + threadIdx.x % 4;
+       c_gl_wr += (2 * thread_n_blocks) * slice_col;
+       constexpr int c_sh_wr_delta = active_threads;
+       int c_sh_wr = threadIdx.x;
+ 
+       int row = (threadIdx.x % 32) / 4;
+ 
+       if (!first) {
+ // Interestingly, doing direct global accesses here really seems to mess up the
+ // compiler and lead to slowdowns, hence we also use async-copies even though
+ // these fetches are not actually asynchronous.
+ #pragma unroll
+         for (int i = 0; i < thread_m_blocks * 4; i++) {
+           cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
+                          &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
+                             c_gl_wr_delta_i * (i % 2)],
+                          i < (thread_m_blocks - 1) * 4 ||
+                              8 * (i / 2) + row < prob_m);
+         }
+         cp_async_fence();
+         cp_async_wait<0>();
+       }
+ 
+ #pragma unroll
+       for (int i = 0; i < thread_m_blocks * 4; i++) {
+         if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
+           if (!first) {
+             int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+ #pragma unroll
+             for (int j = 0; j < 2 * 4; j++) {
+               reinterpret_cast<float *>(
+                   &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
+                   __half2float(reinterpret_cast<__half *>(&c_red)[j]);
+             }
+           }
+           if (!last) {
+             int4 c;
+ #pragma unroll
+             for (int j = 0; j < 2 * 4; j++) {
+               reinterpret_cast<__half *>(&c)[j] =
+                   __float2half(reinterpret_cast<float *>(
+                       &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
+             }
+             C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
+                 c;
+           }
+         }
+       }
+     }
+   };
+ 
+   // Write out the reduced final result in the correct layout. We only actually
+   // reshuffle matrix fragments in this step, the reduction above is performed
+   // in fragment layout.
+   auto write_result = [&]() {
+     int c_gl_stride = prob_n / 8;
+     constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+     int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+     constexpr int c_sh_rd_delta =
+         c_sh_stride * (threads / (2 * thread_n_blocks));
+ 
+     int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+                   (threadIdx.x % (2 * thread_n_blocks));
+     c_gl_wr += (2 * thread_n_blocks) * slice_col;
+     int c_sh_wr =
+         (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+     c_sh_wr += 32 * (threadIdx.x / 32);
+     int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+                   (threadIdx.x % (2 * thread_n_blocks));
+ 
+     int c_gl_wr_end = c_gl_stride * prob_m;
+ 
+     // We first reorder in shared memory to guarantee the most efficient final
+     // global write patterns
+     auto write = [&](int idx, float c0, float c1, FragS &s) {
+       half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+       if (group_blocks ==
+           -1) // for per-column quantization we finally apply the scale here
+         res = __hmul2(res, s[0]);
+       ((half2 *)sh)[idx] = res;
+     };
+     if (threadIdx.x / 32 < thread_n_blocks / 4) {
+ #pragma unroll
+       for (int i = 0; i < thread_m_blocks; i++) {
+ #pragma unroll
+         for (int j = 0; j < 4; j++) {
+           int wr = c_sh_wr + 8 * j;
+           write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
+                 frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+           write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
+                 frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+           write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
+                 frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+           write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
+                 frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+         }
+         c_sh_wr += 16 * (4 * c_sh_stride);
+       }
+     }
+     __syncthreads();
+ 
+ #pragma unroll
+     for (int i = 0;
+          i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
+          i++) {
+       if (c_gl_wr < c_gl_wr_end) {
+         C[c_gl_wr] = sh[c_sh_rd];
+         c_gl_wr += c_gl_wr_delta;
+         c_sh_rd += c_sh_rd_delta;
+       }
+     }
+   };
+ 
+   // Start global fetch and register load pipelines.
+   auto start_pipes = [&]() {
+ #pragma unroll
+     for (int i = 0; i < stages - 1; i++)
+       fetch_to_shared(i, i, i < slice_iters);
+     zero_accums();
+     wait_for_stage();
+     fetch_to_registers(0, 0);
+     a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+   };
+   start_pipes();
+ 
+   // Main loop.
+   while (slice_iters) {
+ // We unroll over both the global fetch and the register load pipeline to ensure
+ // all shared memory accesses are static. Note that both pipelines have even
+ // length meaning that the next iteration will always start at index 0.
+ #pragma unroll
+     for (int pipe = 0; pipe < stages;) {
+ #pragma unroll
+       for (int k = 0; k < b_sh_wr_iters; k++) {
+         fetch_to_registers(k + 1, pipe % stages);
+         if (k == b_sh_wr_iters - 2) {
+           fetch_to_shared((pipe + stages - 1) % stages, pipe,
+                           slice_iters >= stages);
+           pipe++;
+           wait_for_stage();
+         }
+         matmul(k);
+       }
+       slice_iters--;
+       if (slice_iters == 0)
+         break;
+     }
+     a_gl_rd += a_gl_rd_delta_o * stages;
+ 
+     // Process results and, if necessary, proceed to the next column slice.
+     // While this pattern may not be the most readable, other ways of writing
+     // the loop seemed to lead to noticeably worse performance after compilation.
+     if (slice_iters == 0) {
+       cp_async_wait<0>();
+       bool last = slice_idx == slice_count - 1;
+       // For per-column scales, we only fetch them here in the final step before
+       // write-out
+       if (group_blocks == -1 && last) {
+         if (s_sh_wr_pred)
+           cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+         cp_async_fence();
+       }
+       thread_block_reduce();
+       if (group_blocks == -1 && last) {
+         cp_async_wait<0>();
+         __syncthreads();
+         if (threadIdx.x / 32 < thread_n_blocks / 4) {
+           reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+           reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+         }
+       }
+       if (slice_count > 1) { // only globally reduce if there is more than one
+                              // block in a slice
+         barrier_acquire(&locks[slice_col], slice_idx);
+         global_reduce(slice_idx == 0, last);
+         barrier_release(&locks[slice_col], last);
+       }
+       if (last) // only the last block in a slice actually writes the result
+         write_result();
+       slice_row = 0;
+       slice_col_par++;
+       slice_col++;
+       init_slice();
+       if (slice_iters) {
+         a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                   (threadIdx.x % a_gl_rd_delta_o);
+ #pragma unroll
+         for (int i = 0; i < b_sh_wr_iters; i++)
+           B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+         if (slice_col == 0) {
+ #pragma unroll
+           for (int i = 0; i < b_sh_wr_iters; i++)
+             B_ptr[i] -= b_gl_stride;
+         }
+         s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+         start_pipes();
+       }
+     }
+   }
+ }
+ 
+ #else
+ 
+ template <const int threads,         // number of threads in a threadblock
+           const int thread_m_blocks, // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the threadblock
+           const int thread_n_blocks, // same for n dimension (output)
+           const int thread_k_blocks, // same for k dimension (reduction)
+           const int stages, // number of stages for the async global->shared
+                             // fetch pipeline
+           const int group_blocks = -1 // number of consecutive 16x16 blocks with
+                                       // a separate quantization scale
+           >
+ __global__ void
+ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+        const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
+        int4 *__restrict__ C,       // fp16 output buffer of shape mxn
+        const int4
+            *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
+        int prob_m,          // batch dimension m
+        int prob_n,          // output dimension n
+        int prob_k,          // reduction dimension k
+        int *locks           // extra global storage for barrier synchronization
+ ) {
+   // Marlin is not implemented yet for SM < 8.0
+   assert(false);
+   return;
+ }
+ 
+ #endif
+ 
+ // 8 warps are a good choice since every SM has 4 schedulers and having more
+ // than 1 warp per scheduler allows some more latency hiding. At the same time,
+ // we want relatively few warps to have many registers per warp and small tiles.
+ const int USER_THREADS =
+     256;              // Note: This is only used with user-provided thread_k/n
+ const int STAGES = 4; // 4 pipeline stages fit into shared memory
+ const int SHARED_MEM =
+     96 * 1024; // max shared memory on compute capability 8.6 (less than on 8.0)
+ 
+ static constexpr int min_thread_n = 64;
+ static constexpr int min_thread_k = 64;
+ 
+ static constexpr int tile_size = 16;
+ static constexpr int max_par = 16;
+ 
+ static constexpr int pack_factor_4bit =
+     8; // We have 8 4-bit vals inside a 32-bit int
+ 
+ #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,           \
+                   GROUP_BLOCKS, NUM_THREADS)                                   \
+   else if (thread_m_blocks == THREAD_M_BLOCKS &&                               \
+            thread_n_blocks == THREAD_N_BLOCKS &&                               \
+            thread_k_blocks == THREAD_K_BLOCKS &&                               \
+            group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
+     cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
+                                 THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,        \
+                          cudaFuncAttributeMaxDynamicSharedMemorySize,          \
+                          SHARED_MEM);                                          \
+     Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
+            STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
+         A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks);            \
+   }
+ 
+ typedef struct {
+   int thread_k;
+   int thread_n;
+   int num_threads;
+ } thread_config_t;
+ 
+ thread_config_t small_batch_thread_configs[] = {
+     // Ordered by priority
+ 
+     // thread_k, thread_n, num_threads
+     {128, 128, 256}, // Default
+     {128, 64, 128},  // Reduce N 2X, same K
+     {64, 256, 256},  // Reduce K 2X, increase N 2X
+     {64, 128, 128},  // Reduce K 2X, same N
+ };
+ 
+ thread_config_t large_batch_thread_configs[] = {
+     // Ordered by priority
+ 
+     // thread_k, thread_n, num_threads
+     {64, 256, 256},  // Default
+     {128, 128, 256}, // Reduce N 2X, increase K 2X
+     {64, 128, 128},  // Reduce N 2X, same K
+     {128, 64, 128},  // Reduce N 4X, increase K 2X
+ };
+ 
+ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
+                      int prob_k) {
+   // Sanity
+   if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
+       th_config.num_threads == -1) {
+     return false;
+   }
+ 
+   // Verify K/N are divisible by thread K/N
+   if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
+     return false;
+   }
+ 
+   // thread_k can only be 128 or 64 (because it must not exceed the groupsize,
+   // which is 128)
+   if (th_config.thread_k != 128 && th_config.thread_k != 64) {
+     return false;
+   }
+ 
+   // Verify min for thread K/N
+   if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
+     return false;
+   }
+ 
+   // num_threads must be at least 128 (= 4 warps)
+   if (th_config.num_threads < 128) {
+     return false;
+   }
+ 
+   return true;
+ }
+ 
+ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
+ 
+   if (prob_m <= 16) {
+     for (auto th_config : small_batch_thread_configs) {
+       if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
+         return th_config;
+       }
+     }
+ 
+   } else {
+     for (auto th_config : large_batch_thread_configs) {
+       if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
+         return th_config;
+       }
+     }
+   }
+ 
+   return thread_config_t{-1, -1, -1};
+ }
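Illustrative usage (outside this diff): for a small-batch problem such as
m = 16, k = 4096, n = 4096 the first small-batch entry already passes
is_valid_config(), so the auto-config path would select

    thread_config_t cfg =
        determine_thread_config(/*prob_m=*/16, /*prob_n=*/4096, /*prob_k=*/4096);
    // cfg == {thread_k = 128, thread_n = 128, num_threads = 256}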
+ 
+ #define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)                               \
+   __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)                            \
+   __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)                             \
+   __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)                            \
+   __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)                             \
+   __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)                            \
+   __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)                             \
+   __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)                            \
+   __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)                             \
+   __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)                            \
+   __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
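Illustrative expansion (outside this diff): CALL_IF(16, 4, 256) emits else-if
branches for thread_m_blocks 1 through 4 with group_blocks -1 and 8; for
example, the branch taken when thread_m_blocks == 2, group_blocks == 8 and
num_threads == 256 raises the dynamic shared memory limit and launches

    Marlin<256, 2, 16, 4, STAGES, 8>
        <<<blocks, 256, SHARED_MEM, stream>>>(A_ptr, B_ptr, C_ptr, s_ptr,
                                              prob_m, prob_n, prob_k, locks);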
+ 
+ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
+                  int prob_n, int prob_k, void *workspace, int groupsize = -1,
+                  int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
+                  int thread_n = -1, int sms = -1, int max_par = 16) {
+   int tot_m = prob_m;
+   int tot_m_blocks = ceildiv(tot_m, 16);
+   int pad = 16 * tot_m_blocks - tot_m;
+ 
+   if (sms == -1)
+     cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
+ 
+   // Set thread config
+   thread_config_t th_config;
+   if (thread_k != -1 && thread_n != -1) {
+     // User-defined config
+     th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
+   } else {
+     // Auto config
+     th_config = determine_thread_config(prob_m, prob_n, prob_k);
+   }
+ 
+   if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
+     throw std::runtime_error(
+         "Invalid thread config: thread_k = " + str(th_config.thread_k) +
+         ", thread_n = " + str(th_config.thread_n) +
+         ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
+         str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
+   }
+ 
+   // Uncomment for debug
+   // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
+   //                  ", thread_n = " + str(th_config.thread_n) +
+   //                  ", num_threads = " + str(th_config.num_threads) +
+   //                  " for MKN = [" + str(prob_m) +
+   //                  ", " + str(prob_k) + ", " + str(prob_n) + "]\n";
+ 
+   int num_threads = th_config.num_threads;
+   thread_k = th_config.thread_k;
+   thread_n = th_config.thread_n;
+ 
+   int thread_k_blocks = thread_k / 16;
+   int thread_n_blocks = thread_n / 16;
+   int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
+   int blocks = sms;
+ 
+   if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
+     return;
+   }
+ 
+   TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
+               " is not divisible by thread_n = ", thread_n);
+   TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
+               " is not divisible by thread_k = ", thread_k);
+   if (group_blocks != -1) {
+     TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
+                 " is not divisible by group_blocks = ", group_blocks);
+   }
+ 
+   const int4 *A_ptr = (const int4 *)A;
+   const int4 *B_ptr = (const int4 *)B;
+   int4 *C_ptr = (int4 *)C;
+   const int4 *s_ptr = (const int4 *)s;
+ 
+   int *locks = (int *)workspace;
+ 
+   for (int i = 0; i < tot_m_blocks; i += 4) {
+     int thread_m_blocks = tot_m_blocks - i;
+     prob_m = tot_m - 16 * i;
+     int par = 1;
+     if (thread_m_blocks > 4) {
+       // Note that parallel > 1 currently only works for inputs without any
+       // padding
+       par = (16 * thread_m_blocks - pad) / 64;
+       if (par > max_par)
+         par = max_par;
+       prob_m = 64 * par;
+       i += 4 * (par - 1);
+       thread_m_blocks = 4;
+     }
+ 
+     // For compilation speed, we only define the kernel configurations that have
+     // seemed useful (in terms of performance) in our testing; many more are, in
+     // principle, possible.
+     if (false) {
+     }
+     CALL_IF(8, 8, 256)
+     CALL_IF(16, 4, 256)
+     CALL_IF(8, 4, 128)
+     CALL_IF(4, 8, 128)
+     else {
+       throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
+                                ", " + str(prob_k) + ", " + str(prob_n) + "]" +
+                                ", groupsize = " + str(groupsize) +
+                                ", thread_m_blocks = " + str(thread_m_blocks) +
+                                ", thread_n_blocks = " + str(thread_n_blocks) +
+                                ", thread_k_blocks = " + str(thread_k_blocks));
+     }
+ 
+     A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+     C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
+   }
+ }
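
For intuition, the M loop above processes at most 4 row blocks (64 rows) per kernel launch and, when more than 4 blocks remain, folds up to max_par 64-row chunks into a single launch via the par counter. The following is a hedged Python transcription of just that scheduling arithmetic; the function and variable names are made up for illustration, while the formulas are the ones in marlin_cuda.

def ceildiv(a, b):
    return -(-a // b)

def m_schedule(tot_m, max_par=16):
    """Return one (prob_m, thread_m_blocks, par) tuple per kernel launch."""
    tot_m_blocks = ceildiv(tot_m, 16)
    pad = 16 * tot_m_blocks - tot_m
    launches = []
    i = 0
    while i < tot_m_blocks:
        thread_m_blocks = tot_m_blocks - i
        prob_m = tot_m - 16 * i
        par = 1
        if thread_m_blocks > 4:
            # parallel > 1 is only meaningful when the input needs no padding
            par = min((16 * thread_m_blocks - pad) // 64, max_par)
            prob_m = 64 * par
            i += 4 * (par - 1)
            thread_m_blocks = 4
        launches.append((prob_m, thread_m_blocks, par))
        i += 4
    return launches

# m_schedule(256) -> [(256, 4, 4)]   one launch covering four parallel 64-row chunks
# m_schedule(40)  -> [(40, 3, 1)]    a single, partially filled launch
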
+ 
+ } // namespace marlin
+ 
+ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
+                           torch::Tensor &b_scales, torch::Tensor &workspace,
+                           int64_t size_m, int64_t size_n, int64_t size_k) {
+ 
+   // Verify M
+   TORCH_CHECK(size_m == a.size(0),
+               "Shape mismatch: a.size(0) = " + str(a.size(0)) +
+                   ", size_m = " + str(size_m));
+ 
+   // Verify K
+   TORCH_CHECK(size_k == a.size(1),
+               "Shape mismatch: a.size(1) = " + str(a.size(1)) +
+                   ", size_k = " + str(size_k));
+   TORCH_CHECK(size_k % marlin::tile_size == 0,
+               "size_k = " + str(size_k) +
+                   " is not divisible by tile_size = " + str(marlin::tile_size));
+   TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
+               "Shape mismatch: b_q_weight.size(0) = " +
+                   str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
+                   ", tile_size = " + str(marlin::tile_size));
+ 
+   // Verify N
+   TORCH_CHECK(b_scales.size(1) == size_n,
+               "b_scales.size(1) = " + str(b_scales.size(1)) +
+                   ", size_n = " + str(size_n));
+   TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
+               "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
+                   " is not divisible by tile_size = " + str(marlin::tile_size));
+ 
+   int actual_size_n =
+       (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
+   TORCH_CHECK(size_n == actual_size_n,
+               "size_n = " + str(size_n) +
+                   ", actual_size_n = " + str(actual_size_n));
+ 
+   // Verify A device and strides
+   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
+   TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+ 
+   // Verify B device and strides
+   TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
+   TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
+ 
+   // Verify scales device and strides
+   TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
+   TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
+ 
+   // Alloc C matrix
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
+   torch::Tensor c = torch::empty({size_m, size_n}, options);
+ 
+   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
+   // auto -1)
+   int thread_k = -1;
+   // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
+   // auto -1)
+   int thread_n = -1;
+   // sms: number of SMs to use for the kernel (can usually be left as auto -1)
+   int sms = -1;
+ 
+   // Detect groupsize
+   if (b_scales.size(0) != 1) {
+     TORCH_CHECK(size_k % b_scales.size(0) == 0,
+                 "size_k = " + str(size_k) +
+                     ", is not divisible by b_scales.size(0) = " +
+                     str(b_scales.size(0)));
+   }
+   int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);
+ 
+   // Verify groupsize
+   TORCH_CHECK(groupsize == -1 || groupsize == 128,
+               "Unexpected groupsize = " + str(groupsize));
+ 
+   // Verify workspace size
+   TORCH_CHECK(
+       size_n % marlin::min_thread_n == 0,
+       "size_n = " + str(size_n) +
+           ", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
+   int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
+   TORCH_CHECK(workspace.numel() >= min_workspace_size,
+               "workspace.numel = " + str(workspace.numel()) +
+                   " is below min_workspace_size = " + str(min_workspace_size));
+ 
+   int dev = a.get_device();
+   marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
+                       b_scales.data_ptr(), size_m, size_n, size_k,
+                       workspace.data_ptr(), groupsize, dev,
+                       at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
+                       sms, marlin::max_par);
+ 
+   return c;
+ }
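
A hedged usage sketch for the marlin_gemm binding above. The import path and the zero-filled workspace are assumptions (the extension may be exposed under a different module name); the shapes, dtypes and the workspace lower bound follow the TORCH_CHECKs in marlin_gemm, taking tile_size = 16, pack_factor_4bit = 8, min_thread_n = 64 and max_par = 16 for the constants defined earlier in this file. Random int32 data has the right dimensions but is not a valid Marlin packing, so the numerical result is meaningless; real weights come from the GPTQ-to-Marlin conversion script.

import torch
from aphrodite._C import ops  # assumed module path for the compiled extension

m, k, n, groupsize = 16, 4096, 4096, 128

a = torch.randn(m, k, dtype=torch.half, device="cuda")
# Packed 4-bit weights: [k / tile_size, n * tile_size / pack_factor_4bit], int32.
b_q_weight = torch.randint(0, 2**31 - 1, (k // 16, n * 16 // 8),
                           dtype=torch.int32, device="cuda")
# Per-group scales: [k / groupsize, n] (or [1, n] for channel-wise scales).
b_scales = torch.randn(k // groupsize, n, dtype=torch.half, device="cuda")
# One int lock slot per N block, times max_par: (n / min_thread_n) * max_par.
workspace = torch.zeros((n // 64) * 16, dtype=torch.int, device="cuda")

c = ops.marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k)
assert c.shape == (m, n)
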
+ 

+ 5 - 4
tests/benchmarks/throughput.py

@@ -207,10 +207,11 @@ if __name__ == "__main__":
                         type=str,
                         default="EleutherAI/pythia-70m-deduped")
     parser.add_argument("--tokenizer", type=str, default=None)
-    parser.add_argument("--quantization",
-                        "-q",
-                        choices=["awq", "gguf", "gptq", "squeezellm", None],
-                        default=None)
+    parser.add_argument(
+        "--quantization",
+        "-q",
+        choices=["awq", "gguf", "gptq", "squeezellm", "marlin", None],
+        default=None)
     parser.add_argument("--gpu-memory-utilization", type=float, default=0.88)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",

+ 41 - 5
tests/conftest.py

@@ -13,12 +13,10 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
 
-def _read_prompts(filename: str) -> str:
-    prompts = []
+def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
-        prompt = f.readline()
-        prompts.append(prompt)
-    return prompts
+        prompts = f.readlines()
+        return prompts
 
 
 @pytest.fixture
@@ -165,6 +163,9 @@ class AphroditeRunner:
         model_name: str,
         tokenizer_name: Optional[str] = None,
         dtype: str = "half",
+        disable_log_stats: bool = True,
+        tensor_parallel_size: int = 1,
+        **kwargs,
     ) -> None:
         self.model = LLM(
             model=model_name,
@@ -172,6 +173,9 @@ class AphroditeRunner:
             trust_remote_code=True,
             dtype=dtype,
             swap_space=0,
+            disable_log_stats=disable_log_stats,
+            tensor_parallel_size=tensor_parallel_size,
+            **kwargs,
         )
 
     def generate(
@@ -195,6 +199,24 @@ class AphroditeRunner:
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
+    def generate_w_logprobs(
+        self,
+        prompts: List[str],
+        sampling_params: SamplingParams,
+    ) -> List[Tuple[List[int], str, List[dict]]]:
+        assert sampling_params.logprobs is not None
+
+        req_outputs = self.model.generate(prompts,
+                                          sampling_params=sampling_params)
+        outputs = []
+        for req_output in req_outputs:
+            for sample in req_output.outputs:
+                output_str = sample.text
+                output_ids = sample.token_ids
+                output_logprobs = sample.logprobs
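+            # appended once per request: with n > 1 only the last sample's output is kept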
+            outputs.append((output_ids, output_str, output_logprobs))
+        return outputs
+
     def generate_greedy(
         self,
         prompts: List[str],
@@ -205,6 +227,20 @@ class AphroditeRunner:
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
 
+    def generate_greedy_logprobs(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str, List[dict]]]:
+        greedy_logprobs_params = SamplingParams(temperature=0.0,
+                                                max_tokens=max_tokens,
+                                                logprobs=num_logprobs)
+        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
+
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def generate_beam_search(
         self,
         prompts: List[str],