Browse Source

chore: refactor marlin python utils

AlpinDale 6 months ago
parent
commit
058e629f8e

+ 77 - 84
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

@@ -7,10 +7,10 @@ from aphrodite import _custom_ops as ops
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.compressed_tensors.schemes import \
     CompressedTensorsScheme
-from aphrodite.quantization.gptq_marlin import (GPTQ_MARLIN_MAX_PARALLEL,
-                                                GPTQ_MARLIN_MIN_THREAD_N,
-                                                GPTQMarlinState,
-                                                marlin_permute_scales)
+from aphrodite.quantization.utils.marlin_utils import (
+    apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
+    marlin_permute_scales, replace_tensor, verify_marlin_supported,
+    verify_marlin_supports_shape)
 
 __all__ = ["CompressedTensorsWNA16"]
 WNA16_SUPPORTED_BITS = [4, 8]
@@ -23,40 +23,53 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                  num_bits: int,
                  group_size: Optional[int] = None):
         self.num_bits = num_bits
+        self.pack_factor = 32 // self.num_bits
         self.strategy = strategy
-        self.group_size = group_size
 
-        if self.strategy == "group" and self.group_size is None:
-            raise ValueError(
-                "group_size must be given when using strategy group")
+        self.group_size: int
+        if group_size is None:
+            if self.strategy != "channel":
+                raise ValueError(
+                    "Marlin kernels require group quantization or "
+                    "channelwise quantization, but found no group "
+                    "size and strategy is not channelwise.")
+            self.group_size = -1
+        else:
+            self.group_size = group_size
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        pass
+        # Verify supported on platform.
+        verify_marlin_supported(num_bits=self.num_bits,
+                                group_size=self.group_size,
+                                is_sym=True)
 
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-
-        pack_factor = 32 // self.num_bits
         output_size_per_partition = sum(output_partition_sizes)
 
-        if self.group_size is not None:
-            group_size = self.group_size
-        else:
-            group_size = input_size
+        # If group_size is -1, we are in channelwise case.
+        group_size = input_size if self.group_size == -1 else self.group_size
+
+        verify_marlin_supports_shape(
+            output_size_per_partition=output_size_per_partition,
+            input_size_per_partition=input_size_per_partition,
+            input_size=input_size,
+            group_size=group_size)
 
+        weight_scale_dim = None
         scales_and_zp_size = input_size // group_size
 
         if (input_size != input_size_per_partition
                 and self.group_size is not None):
+            weight_scale_dim = 1
             scales_and_zp_size = input_size_per_partition // group_size
 
         weight = Parameter(
             torch.empty(
                 output_size_per_partition,
-                input_size_per_partition // pack_factor,
+                input_size_per_partition // self.pack_factor,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -67,10 +80,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 "input_dim": 1,
                 "output_dim": 0,
                 "packed_dim": 1,
-                "pack_factor": pack_factor,
+                "pack_factor": self.pack_factor,
                 "weight_loader": weight_loader
             })
-
         layer.register_parameter("weight_packed", weight)
 
         weight_scale = Parameter(
@@ -82,6 +94,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             requires_grad=False,
         )
 
+        set_weight_attrs(
+            weight_scale, {
+                "weight_loader": weight_loader,
+                "input_dim": weight_scale_dim,
+                "output_dim": 0
+            })
         layer.register_parameter("weight_scale", weight_scale)
 
         # A 2D array defining the original shape of the weights
@@ -97,73 +115,48 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
-
         layer.input_size = input_size
-        layer.marlin_state = GPTQMarlinState.REPACK
-        layer.is_k_full = True
         layer.group_size = group_size
 
-        max_workspace_size = (
-            output_size_per_partition //
-            GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
-
-        workspace = torch.zeros(max_workspace_size,
-                                dtype=torch.int,
-                                requires_grad=False)
-        layer.workspace = workspace
+    # Checkpoints are serialized in compressed-tensors format, which is
+    # different from marlin format. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        device = layer.weight_packed.device
+
+        # Allocate marlin workspace.
+        layer.workspace = marlin_make_workspace(
+            layer.output_size_per_partition, device)
+
+        # Act-order not supported in compressed-tensors yet, so set to empty.
+        layer.g_idx = marlin_make_empty_g_idx(device)
+        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        # Repack weights from compressed-tensors format to marlin format.
+        marlin_qweight = ops.gptq_marlin_repack(
+            layer.weight_packed.t().contiguous(),
+            perm=layer.g_idx_sort_indices,
+            size_k=layer.input_size_per_partition,
+            size_n=layer.output_size_per_partition,
+            num_bits=self.num_bits)
+        replace_tensor(layer, "weight_packed", marlin_qweight)
+
+        # Permute scales from compressed-tensors format to marlin format.
+        marlin_scales = marlin_permute_scales(
+            layer.weight_scale.squeeze().t().contiguous(),
+            size_k=layer.input_size_per_partition,
+            size_n=layer.output_size_per_partition,
+            group_size=layer.group_size)
+        replace_tensor(layer, "weight_scale", marlin_scales)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        reshaped_x = x.reshape(-1, x.shape[-1])
-
-        size_m = reshaped_x.shape[0]
-        part_size_n = layer.output_size_per_partition
-        part_size_k = layer.input_size_per_partition
-
-        out_shape = x.shape[:-1] + (part_size_n, )
-
-        if layer.marlin_state == GPTQMarlinState.REPACK:
-            layer.marlin_state = GPTQMarlinState.READY
-
-            # Newly generated tensors need to replace existing tensors that are
-            # already registered as parameters by Aphrodite (and won't be freed)
-            def replace_tensor(name, new_t):
-                # It is important to use resize_() here since it ensures
-                # the same buffer is reused
-                getattr(layer, name).resize_(new_t.shape)
-                getattr(layer, name).copy_(new_t)
-                del new_t
-
-            cur_device = layer.weight_packed.device
-
-            # Reset g_idx related tensors
-            layer.g_idx = Parameter(torch.empty(0,
-                                                dtype=torch.int,
-                                                device=cur_device),
-                                    requires_grad=False)
-            layer.g_idx_sort_indices = Parameter(torch.empty(
-                0, dtype=torch.int, device=cur_device),
-                                                 requires_grad=False)
-
-            # Repack weights
-            marlin_qweight = ops.gptq_marlin_repack(
-                layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
-                part_size_k, part_size_n, self.num_bits)
-
-            replace_tensor("weight_packed", marlin_qweight)
-
-            # Permute scales
-            scales_size_k = part_size_k
-            scales_size_n = part_size_n
-
-            marlin_scales = marlin_permute_scales(
-                layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
-                scales_size_n, layer.group_size, self.num_bits)
-            replace_tensor("weight_scale", marlin_scales)
-
-        output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
-                                      layer.weight_scale, layer.g_idx,
-                                      layer.g_idx_sort_indices,
-                                      layer.workspace, self.num_bits, size_m,
-                                      part_size_n, part_size_k,
-                                      layer.is_k_full)
-        return output.reshape(out_shape)
+        return apply_marlin_linear(
+            input=x,
+            weight=layer.weight_packed,
+            weight_scale=layer.weight_scale,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=layer.workspace,
+            num_bits=self.num_bits,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            is_k_full=True)

+ 1 - 1
aphrodite/quantization/fp8.py

@@ -14,7 +14,7 @@ from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
-from aphrodite.quantization.utils.marlin_utils import (
+from aphrodite.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from aphrodite.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,

+ 67 - 196
aphrodite/quantization/gptq_marlin.py

@@ -1,5 +1,3 @@
-import enum
-from enum import Enum
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -10,43 +8,11 @@ from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
-from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.utils.marlin_utils import (
-    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K,
-    GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES,
-    GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM,
-    GPTQ_MARLIN_TILE)
-
-
-# Permutations for Marlin scale shuffling
-def get_scale_perms(num_bits: int):
-    scale_perm: List[int] = []
-    for i in range(8):
-        scale_perm.extend([i + 8 * j for j in range(8)])
-    scale_perm_single: List[int] = []
-    for i in range(4):
-        scale_perm_single.extend(
-            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-    return scale_perm, scale_perm_single
-
-
-def get_pack_factor(num_bits: int):
-    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
-            ), f"Unsupported num_bits = {num_bits}"
-    return 32 // num_bits
-
-
-def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
-                          group_size: int, num_bits: int):
-    scale_perm, scale_perm_single = get_scale_perms(num_bits)
-    if group_size < size_k and group_size != -1:
-        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-    else:
-        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
-    s = s.reshape((-1, size_n)).contiguous()
-
-    return s
+    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
+    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    verify_marlin_supported, verify_marlin_supports_shape)
 
 
 class GPTQMarlinConfig(QuantizationConfig):
@@ -60,33 +26,16 @@ class GPTQMarlinConfig(QuantizationConfig):
             desc_act = False
 
         self.weight_bits = weight_bits
+        self.pack_factor = 32 // self.weight_bits  # packed into int32
         self.group_size = group_size
         self.desc_act = desc_act
         self.is_sym = is_sym
         self.lm_head_quantized = lm_head_quantized
 
-        # Verify
-        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
-            raise ValueError(
-                f"Marlin does not support weight_bits = {self.weight_bits}. "
-                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
-                "are supported.")
-        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
-            raise ValueError(
-                f"Marlin does not support group_size = {self.group_size}. "
-                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
-                "are supported.")
-        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
-            raise ValueError(
-                f"Marlin does not support is_sym = {self.is_sym}. "
-                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
-
-        # Init
-        self.pack_factor = get_pack_factor(weight_bits)
-        self.tile_size = GPTQ_MARLIN_TILE
-        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
-        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
-        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
+        # Verify supported on platform.
+        verify_marlin_supported(num_bits=self.weight_bits,
+                                group_size=self.group_size,
+                                is_sym=self.is_sym)
 
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
@@ -165,21 +114,10 @@ class GPTQMarlinConfig(QuantizationConfig):
                 or desc_act is None):
             return False
 
-        # If the capability of the device is too low, cannot convert.
-        major, minor = current_platform.get_device_capability()
-        device_capability = major * 10 + minor
-        if device_capability < cls.get_min_capability():
-            return False
-
-        # Otherwise, can convert if model satisfies marlin constraints.
-        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
-                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
-                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
-
-
-class GPTQMarlinState(Enum):
-    REPACK = enum.auto()
-    READY = enum.auto()
+        return check_marlin_supported(num_bits=num_bits,
+                                      group_size=group_size,
+                                      is_sym=sym,
+                                      min_capability=cls.get_min_capability())
 
 
 class GPTQMarlinLinearMethod(LinearMethodBase):
@@ -203,6 +141,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ) -> None:
         del output_size
+        output_size_per_partition = sum(output_partition_sizes)
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -210,31 +149,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         else:
             group_size = input_size
 
-        # Validate dtype
-        if params_dtype not in [torch.float16, torch.bfloat16]:
-            raise ValueError(f"The params dtype must be float16 "
-                             f"or bfloat16, but got {params_dtype}")
-
-        # Validate output_size_per_partition
-        output_size_per_partition = sum(output_partition_sizes)
-        if output_size_per_partition % self.quant_config.min_thread_n != 0:
-            raise ValueError(
-                f"Weight output_size_per_partition = "
-                f"{output_size_per_partition} is not divisible by "
-                f" min_thread_n = {self.quant_config.min_thread_n}.")
-
-        # Validate input_size_per_partition
-        if input_size_per_partition % self.quant_config.min_thread_k != 0:
-            raise ValueError(
-                f"Weight input_size_per_partition = "
-                f"{input_size_per_partition} is not divisible "
-                f"by min_thread_k = {self.quant_config.min_thread_k}.")
-
-        if (group_size < input_size
-                and input_size_per_partition % group_size != 0):
-            raise ValueError(
-                f"Weight input_size_per_partition = {input_size_per_partition}"
-                f" is not divisible by group_size = {group_size}.")
+        verify_marlin_supports_shape(
+            output_size_per_partition=output_size_per_partition,
+            input_size_per_partition=input_size_per_partition,
+            input_size=input_size,
+            group_size=group_size)
 
         # Detect sharding of scales/zp
 
@@ -300,11 +219,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             },
         )
 
-        g_idx_sort_indices = torch.empty(
-            g_idx.shape,
-            dtype=torch.int32,
-        )
-
         # Scales
         scales = Parameter(
             torch.empty(
@@ -344,25 +258,50 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             },
         )
 
-        # Allocate marlin workspace
-        max_workspace_size = (
-            output_size_per_partition //
-            self.quant_config.min_thread_n) * self.quant_config.max_parallel
-        workspace = torch.zeros(max_workspace_size,
-                                dtype=torch.int,
-                                requires_grad=False)
-
         layer.register_parameter("qweight", qweight)
         layer.register_parameter("g_idx", g_idx)
         layer.register_parameter("scales", scales)
         layer.register_parameter("qzeros", qzeros)
-        layer.g_idx_sort_indices = g_idx_sort_indices
-        layer.workspace = workspace
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
         layer.is_k_full = is_k_full
-        layer.marlin_state = GPTQMarlinState.REPACK
+
+    # Checkpoints are serialized in AutoGPTQ format, which is different from the
+    # marlin format. This function is called after the weights are loaded.
+    # Here, we handle the repacking, including the activation reordering case.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        device = layer.qweight.device
+        # Allocate marlin workspace
+        layer.workspace = marlin_make_workspace(
+            layer.output_size_per_partition, device)
+
+        # Handle sorting for activation reordering if needed.
+        if self.quant_config.desc_act:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "g_idx", g_idx)
+        else:
+            layer.g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        # Repack weights from autogptq format to marlin format.
+        marlin_qweight = ops.gptq_marlin_repack(
+            layer.qweight,
+            perm=layer.g_idx_sort_indices,
+            size_k=layer.input_size_per_partition,
+            size_n=layer.output_size_per_partition,
+            num_bits=self.quant_config.weight_bits)
+        replace_tensor(layer, "qweight", marlin_qweight)
+
+        # Permute scales from autogptq format to marlin format.
+        marlin_scales = marlin_permute_scales(
+            layer.scales,
+            size_k=(layer.input_size if self.quant_config.desc_act else
+                    layer.input_size_per_partition),
+            size_n=layer.output_size_per_partition,
+            group_size=self.quant_config.group_size)
+        replace_tensor(layer, "scales", marlin_scales)
 
     def apply(
         self,
@@ -371,87 +310,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         reshaped_x = x.reshape(-1, x.shape[-1])
-
-        size_m = reshaped_x.shape[0]
-        part_size_n = layer.output_size_per_partition
-        part_size_k = layer.input_size_per_partition
-        full_size_k = layer.input_size
-
-        out_shape = x.shape[:-1] + (part_size_n, )
-
-        if layer.marlin_state == GPTQMarlinState.REPACK:
-            layer.marlin_state = GPTQMarlinState.READY
-
-            # Newly generated tensors need to replace existing tensors that are
-            # already registered as parameters by vLLM (and won't be freed)
-            def replace_tensor(name, new_t):
-                # It is important to use resize_() here since it ensures
-                # the same buffer is reused
-                getattr(layer, name).resize_(new_t.shape)
-                getattr(layer, name).copy_(new_t)
-                del new_t
-
-            cur_device = layer.qweight.device
-
-            # Process act_order
-            if self.quant_config.desc_act:
-                # Get sorting based on g_idx
-                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
-
-                sorted_g_idx = layer.g_idx[g_idx_sort_indices]
-
-                replace_tensor("g_idx", sorted_g_idx)
-                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
-
-            else:
-                # Reset g_idx related tensors
-                layer.g_idx = Parameter(
-                    torch.empty(0, dtype=torch.int, device=cur_device),
-                    requires_grad=False,
-                )
-                layer.g_idx_sort_indices = Parameter(
-                    torch.empty(0, dtype=torch.int, device=cur_device),
-                    requires_grad=False,
-                )
-
-            # Repack weights
-            marlin_qweight = ops.gptq_marlin_repack(
-                layer.qweight,
-                layer.g_idx_sort_indices,
-                part_size_k,
-                part_size_n,
-                self.quant_config.weight_bits,
-            )
-            replace_tensor("qweight", marlin_qweight)
-
-            # Permute scales
-            scales_size_k = part_size_k
-            scales_size_n = part_size_n
-            if self.quant_config.desc_act:
-                scales_size_k = full_size_k
-
-            marlin_scales = marlin_permute_scales(
-                layer.scales,
-                scales_size_k,
-                scales_size_n,
-                self.quant_config.group_size,
-                self.quant_config.weight_bits,
-            )
-            replace_tensor("scales", marlin_scales)
-
-        output = ops.gptq_marlin_gemm(
-            reshaped_x,
-            layer.qweight,
-            layer.scales,
-            layer.g_idx,
-            layer.g_idx_sort_indices,
-            layer.workspace,
-            self.quant_config.weight_bits,
-            size_m,
-            part_size_n,
-            part_size_k,
-            layer.is_k_full,
-        )
+        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
+
+        output = ops.gptq_marlin_gemm(reshaped_x,
+                                      layer.qweight,
+                                      layer.scales,
+                                      g_idx=layer.g_idx,
+                                      perm=layer.g_idx_sort_indices,
+                                      workspace=layer.workspace,
+                                      num_bits=self.quant_config.weight_bits,
+                                      size_m=reshaped_x.shape[0],
+                                      size_n=layer.output_size_per_partition,
+                                      size_k=layer.input_size_per_partition,
+                                      is_k_full=layer.is_k_full)
 
         if bias is not None:
             output.add_(bias)  # In-place add

+ 0 - 59
aphrodite/quantization/utils/marlin_24_perms.py

@@ -1,59 +0,0 @@
-from typing import Dict, List
-
-import numpy
-import torch
-
-
-# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
-#
-# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
-# with the tensor-core format that is described here:
-# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
-#
-# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
-# (without the need to use ldmatrix instructions) # noqa: E501
-def get_perms_24(num_bits: int):
-    perm_list: List[int] = []
-    for i in range(32):
-        perm1: List[int] = []
-        col = i // 4
-        col_o = col // 2
-        for block in [0, 1]:
-            for row in [
-                    2 * (i % 4),
-                    2 * (i % 4) + 1,
-                    2 * (i % 4 + 4),
-                    2 * (i % 4 + 4) + 1,
-            ]:
-                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
-                             4 * block)
-        for j in range(4):
-            perm_list.extend([p + 1 * j for p in perm1])
-    perm = numpy.array(perm_list)
-
-    if num_bits == 4:
-        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
-    elif num_bits == 8:
-        interleave = numpy.array([0, 2, 1, 3])
-    else:
-        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
-
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
-    scale_perm: List[int] = []
-    for i in range(8):
-        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
-    scale_perm_single: List[int] = []
-    for i in range(8):
-        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
-    return perm, scale_perm, scale_perm_single
-
-
-marlin_24_perm: Dict[int, torch.Tensor] = {}
-marlin_24_scale_perm: Dict[int, List[int]] = {}
-marlin_24_scale_perm_single: Dict[int, List[int]] = {}
-for num_bits in [4, 8]:
-    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
-    marlin_24_perm[num_bits] = perm_24
-    marlin_24_scale_perm[num_bits] = scale_perm_24
-    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24

+ 0 - 59
aphrodite/quantization/utils/marlin_perms.py

@@ -1,59 +0,0 @@
-from typing import Dict, List
-
-import numpy
-import torch
-
-
-# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
-#
-# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
-# with the tensor-core format that is described here:
-# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
-#
-# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
-# (without the need to use ldmatrix instructions) # noqa: E501
-def get_perms(num_bits: int):
-    perm_list: List[int] = []
-    for i in range(32):
-        perm1: List[int] = []
-        col = i // 4
-        for block in [0, 1]:
-            for row in [
-                    2 * (i % 4),
-                    2 * (i % 4) + 1,
-                    2 * (i % 4 + 4),
-                    2 * (i % 4 + 4) + 1,
-            ]:
-                perm1.append(16 * row + col + 8 * block)
-        for j in range(4):
-            perm_list.extend([p + 256 * j for p in perm1])
-
-    perm = numpy.array(perm_list)
-
-    if num_bits == 4:
-        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
-    elif num_bits == 8:
-        interleave = numpy.array([0, 2, 1, 3])
-    else:
-        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
-
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
-    scale_perm: List[int] = []
-    for i in range(8):
-        scale_perm.extend([i + 8 * j for j in range(8)])
-    scale_perm_single: List[int] = []
-    for i in range(4):
-        scale_perm_single.extend(
-            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
-    return perm, scale_perm, scale_perm_single
-
-
-marlin_perm: Dict[int, torch.Tensor] = {}
-marlin_scale_perm: Dict[int, List[int]] = {}
-marlin_scale_perm_single: Dict[int, List[int]] = {}
-for num_bits in [4, 8]:
-    perm, scale_perm, scale_perm_single = get_perms(num_bits)
-    marlin_perm[num_bits] = perm
-    marlin_scale_perm[num_bits] = scale_perm
-    marlin_scale_perm_single[num_bits] = scale_perm_single

+ 133 - 308
aphrodite/quantization/utils/marlin_utils.py

@@ -1,23 +1,9 @@
-import random
-from typing import Optional
+from typing import List, Optional, Tuple
 
-import numpy
 import torch
 
 from aphrodite import _custom_ops as ops
-from aphrodite.quantization.utils.format_24 import (
-    mask_creator, sparse_semi_structured_from_dense_cutlass)
-from aphrodite.quantization.utils.marlin_24_perms import (
-    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
-from aphrodite.quantization.utils.marlin_perms import (marlin_perm,
-                                                       marlin_scale_perm,
-                                                       marlin_scale_perm_single
-                                                       )
-from aphrodite.quantization.utils.quant_utils import (get_pack_factor,
-                                                      quantize_weights,
-                                                      sort_weights)
 from aphrodite.platforms import current_platform
-from aphrodite.common.utils import print_warning_once
 
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
@@ -27,135 +13,110 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
 GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
 GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 GPTQ_MARLIN_SUPPORTED_SYM = [True]
-
-
-def is_marlin_supported():
-    capability = current_platform.get_device_capability()
-    return capability[0] >= 8
-
-
-def apply_fp8_marlin_linear(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    workspace: torch.Tensor,
-    size_n: int,
-    size_k: int,
-    bias: Optional[torch.Tensor],
-) -> torch.Tensor:
-    # For GPUs that lack FP8 hardware support, we can leverage the
-    # Marlin kernel for fast weight-only FP8 quantization
-
-    reshaped_x = input.reshape(-1, input.shape[-1])
-    out_shape = input.shape[:-1] + (size_n, )
-
-    output = ops.fp8_marlin_gemm(
-        a=reshaped_x,
-        b_q_weight=weight,
-        b_scales=weight_scale,
-        workspace=workspace,
-        num_bits=8,
-        size_m=reshaped_x.shape[0],
-        size_n=size_n,
-        size_k=size_k,
-    )
-
-    if bias is not None:
-        output.add_(bias)  # In-place add
-
-    return output.reshape(out_shape)
-
-
-def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
-    print_warning_once(
-        "Your GPU does not have native support for FP8 computation but "
-        "FP8 quantization is being used. Weight-only FP8 compression will "
-        "be used leveraging the Marlin kernel. This may degrade "
-        "performance for compute-heavy workloads.")
-
-    part_size_n = layer.output_size_per_partition
-    part_size_k = layer.input_size_per_partition
-
-    device = layer.weight.device
-
-    # WEIGHTS
-    # Repack weights to gptq format (packed int32 elements)
-    packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
-
-    # Repack weights to marlin format
-    marlin_qweight = ops.gptq_marlin_repack(
-        b_q_weight=packed_gptq_qweight,
-        perm=torch.empty(0, dtype=torch.int, device=device),
-        size_k=part_size_k,
-        size_n=part_size_n,
-        num_bits=8,
-    )
-    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
-
-    # WEIGHT SCALES
-    # Currently Marlin doesn't support per-tensor scales, so we
-    # expand it to channelwise
-    scales = layer.weight_scale.repeat(1, part_size_n).to(
-        layer.orig_dtype).to(device)
-    # Permute scales
-    num_bits = 8
-    marlin_scales = marlin_permute_scales(
-        s=scales,
-        size_k=part_size_k,
-        size_n=part_size_n,
-        group_size=-1,
-        scale_perm=marlin_scale_perm[num_bits],
-        scale_perm_single=marlin_scale_perm_single[num_bits])
-    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
-
-    # Allocate marlin workspace
-    max_workspace_size = (part_size_n //
+GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]
+
+
+def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
+                           min_capability: int) -> bool:
+
+    # If the capability of the device is too low, cannot convert.
+    major, minor = current_platform.get_device_capability()
+    device_capability = major * 10 + minor
+    if device_capability < min_capability:
+        return False
+
+    return (device_capability >= min_capability
+            and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
+            and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
+            and is_sym in GPTQ_MARLIN_SUPPORTED_SYM)
+
+
+def verify_marlin_supported(num_bits: int, group_size: Optional[int],
+                            is_sym: bool) -> None:
+
+    if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+        raise ValueError(
+            f"Marlin does not support weight_bits = {num_bits}. "
+            f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
+            "are supported.")
+    if (group_size is None
+            or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
+        raise ValueError(
+            f"Marlin does not support group_size = {group_size}. "
+            f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
+            "are supported.")
+    if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
+        raise ValueError(
+            f"Marlin does not support is_sym = is_sym. "
+            f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
+
+
+def verify_marlin_supports_shape(output_size_per_partition: int,
+                                 input_size_per_partition: int,
+                                 input_size: int, group_size: int) -> None:
+
+    # Validate output_size_per_partition
+    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
+        raise ValueError(f"Weight output_size_per_partition = "
+                         f"{output_size_per_partition} is not divisible by "
+                         f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
+                         "Consider reducing tensor_parallel_size or running "
+                         "with --quantization gptq.")
+
+    # Validate input_size_per_partition
+    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
+        raise ValueError(f"Weight input_size_per_partition = "
+                         f"{input_size_per_partition} is not divisible "
+                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
+                         "Consider reducing tensor_parallel_size or running "
+                         "with --quantization gptq.")
+
+    if (group_size < input_size
+            and input_size_per_partition % group_size != 0):
+        raise ValueError(
+            f"Weight input_size_per_partition = {input_size_per_partition}"
+            f" is not divisible by group_size = {group_size}."
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq.")
+
+
+def marlin_make_workspace(output_size_per_partition: int,
+                          device: torch.device) -> torch.Tensor:
+    max_workspace_size = (output_size_per_partition //
                           GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device=device,
-                            requires_grad=False)
-
-    layer.workspace = workspace
-
-
-def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
-    assert q_w.shape == (size_k, size_n)
-    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
-    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
-
-    # Permute weights to 16x64 marlin tiles
-    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
-    q_w = q_w.permute((0, 2, 1, 3))
-    q_w = q_w.reshape((size_k // tile, size_n * tile))
 
-    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+    return torch.zeros(max_workspace_size,
+                       dtype=torch.int,
+                       device=device,
+                       requires_grad=False)
 
-    return q_w
 
+def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
+                              requires_grad=False)
 
-def marlin_weights(q_w, size_k, size_n, num_bits, perm):
-    # Permute
-    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
 
-    # Pack
-    pack_factor = get_pack_factor(num_bits)
-    orig_device = q_w.device
+def marlin_sort_g_idx(
+        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
+    return g_idx[g_idx_sort_indices], g_idx_sort_indices
 
-    q_w = q_w.cpu().numpy().astype(numpy.uint32)
 
-    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
-                           dtype=numpy.uint32)
-    for i in range(pack_factor):
-        q_packed |= q_w[:, i::pack_factor] << num_bits * i
+def get_scale_perms():
+    scale_perm: List[int] = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single: List[int] = []
+    for i in range(4):
+        scale_perm_single.extend(
+            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return scale_perm, scale_perm_single
 
-    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
 
-    return q_packed
+def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
+                          group_size: int) -> torch.Tensor:
 
-
-def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
-                          scale_perm_single):
+    scale_perm, scale_perm_single = get_scale_perms()
     if group_size < size_k and group_size != -1:
         s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
     else:
@@ -165,180 +126,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
     return s
 
 
-def marlin_quantize(
-    w: torch.Tensor,
-    num_bits: int,
-    group_size: int,
-    act_order: bool,
-):
-    size_k, size_n = w.shape
-
-    # Normalize group_size
-    if group_size == -1:
-        group_size = size_k
-    assert group_size <= size_k
-
-    # Quantize (and apply act_order if provided)
-    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
-                                                       act_order)
-
-    # For act_order, sort the "weights" and "g_idx" so that group ids are
-    # increasing
-    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
-    if act_order:
-        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
-
-    # Reformat to marlin
-    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
-                                marlin_perm[num_bits])
-    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
-                                     marlin_scale_perm[num_bits],
-                                     marlin_scale_perm_single[num_bits])
-
-    # Create result
-    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
-    for i in range(len(res_list)):
-        res_list[i] = res_list[i].to(w.device)
-
-    return res_list
-
-
-def inject_24(w, size_k, size_n):
-    assert w.shape == (size_k, size_n)
-
-    mask = mask_creator(w.t()).t().cuda().bool()
-
-    return (mask * w).contiguous(), mask.contiguous()
-
-
-def check_24(w, num_rows_to_sample=50, _verbose=False):
-    BLOCK_SIZE = 4
-    MAX_NON_ZEROS = 2
-
-    w = w.t().contiguous()
-
-    print("check_24: w.shape = {}".format(w.shape))
-
-    num_rows, num_cols = w.shape
-    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
-    if _verbose:
-        print(f"Sampled row idxs = {sampled_row_idxs}")
-
-    total_segments = 0
-    non_24_segments = 0
-    for i in sampled_row_idxs:
-        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
-            total_segments += 1
-            block = w[i, j:j + BLOCK_SIZE]
-            num_nonzero = torch.count_nonzero(block)
-            if num_nonzero > MAX_NON_ZEROS:
-                print("i = {} j = {} block = {}".format(i, j, block))
-                non_24_segments += 1
-
-    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
-
-
-def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
-    assert q_24.shape == (size_k, size_n)
-
-    # Remove zp to normalize over 0
-    max_q_val = (1 << num_bits) - 1
-    zp = (max_q_val + 1) // 2
-    q_24_no_zp = q_24 - zp
-
-    # Compress
-    q_24_no_zp = q_24_no_zp.t().contiguous()
-    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
-        q_24_no_zp)
-    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
-
-    # Restore zp
-    q_24_comp = q_24_no_zp_comp + zp
-
-    # Resize meta to its actual shape (without moving any data)
-    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
-
-    return q_24_comp, meta
-
-
-def marlin_24_quantize(
-    w: torch.Tensor,
-    num_bits: int,
-    group_size: int,
-):
-    size_k, size_n = w.shape
-
-    # Normalize group_size
-    if group_size == -1:
-        group_size = size_k
-    assert group_size <= size_k
-
-    # Inject 2:4 sparsity
-    w_24, mask_24 = inject_24(w, size_k, size_n)
-
-    # Quantize
-    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
-                                                             num_bits,
-                                                             group_size,
-                                                             act_order=False)
-
-    # Compress quantized weight
-    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
-                                                     num_bits)
-    size_k_comp = size_k // 2
-
-    # Reformat to marlin
-    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
-                                        num_bits, marlin_24_perm[num_bits])
-    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
-                                        marlin_24_scale_perm[num_bits],
-                                        marlin_24_scale_perm_single[num_bits])
-
-    # Create result
-    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
-    for i in range(len(res_list)):
-        res_list[i] = res_list[i].to(w.device)
-
-    return res_list
-
-
-def compute_max_diff(output, output_ref):
-    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
-        torch.abs(output_ref))
-
-
-class MarlinWorkspace:
-
-    def __init__(self, out_features, min_thread_n, max_parallel):
-        assert (out_features % min_thread_n == 0), (
-            "out_features = {} is undivisible by min_thread_n = {}".format(
-                out_features, min_thread_n))
-
-        max_workspace_size = ((out_features // min_thread_n) * max_parallel)
-
-        self.scratch = torch.zeros(max_workspace_size,
-                                   dtype=torch.int,
-                                   device="cuda")
-
-
-def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
-    """
-    Repack FP8 weights to gptq format (packed int32 elements)
-    """
-    assert fp8_tensor.dtype == torch.float8_e4m3fn
-    assert fp8_tensor.shape[0] % 4 == 0
-
-    # Reshape to prepare for packing
-    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
-
-    # Convert fp8 to uint8 (byte) representation
-    byte_tensor = reshaped.view(torch.uint8)
+# Newly generated tensors need to replace existing tensors that are
+# already registered as parameters by vLLM (and won't be freed)
+def replace_tensor(layer: torch.nn.Module, name: str,
+                   new_t: torch.Tensor) -> None:
+    # It is important to use resize_() here since it ensures
+    # the same buffer is reused
+    getattr(layer, name).resize_(new_t.shape)
+    getattr(layer, name).copy_(new_t)
+    del new_t
+
+
+def apply_marlin_linear(input: torch.Tensor,
+                        weight: torch.Tensor,
+                        weight_scale: torch.Tensor,
+                        g_idx: torch.Tensor,
+                        g_idx_sort_indices: torch.Tensor,
+                        workspace: torch.Tensor,
+                        num_bits: int,
+                        output_size_per_partition: int,
+                        input_size_per_partition: int,
+                        is_k_full: bool,
+                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition, )
+
+    output = ops.gptq_marlin_gemm(reshaped_x,
+                                  weight,
+                                  weight_scale,
+                                  g_idx,
+                                  g_idx_sort_indices,
+                                  workspace,
+                                  num_bits,
+                                  size_m=reshaped_x.shape[0],
+                                  size_n=output_size_per_partition,
+                                  size_k=input_size_per_partition,
+                                  is_k_full=is_k_full)
 
-    # Pack 4 uint8 values into one int32
-    packed = (byte_tensor[:, 0].to(torch.int32) |
-              (byte_tensor[:, 1].to(torch.int32) << 8) |
-              (byte_tensor[:, 2].to(torch.int32) << 16) |
-              (byte_tensor[:, 3].to(torch.int32) << 24))
+    if bias is not None:
+        output.add_(bias)  # In-place add
 
-    return packed.view(fp8_tensor.shape[0] // 4,
-                       *fp8_tensor.shape[1:]).contiguous()
+    return output.reshape(out_shape)

+ 109 - 0
aphrodite/quantization/utils/marlin_utils_fp8.py

@@ -0,0 +1,109 @@
+from typing import Optional
+
+import torch
+
+import aphrodite._custom_ops as ops
+from aphrodite.common.utils import print_warning_once
+from aphrodite.platforms import current_platform
+
+from .marlin_utils import marlin_make_workspace, marlin_permute_scales
+
+
+def is_fp8_marlin_supported():
+    capability = current_platform.get_device_capability()
+    return capability[0] >= 8
+
+
+def apply_fp8_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n, )
+
+    output = ops.fp8_marlin_gemm(
+        a=reshaped_x,
+        b_q_weight=weight,
+        b_scales=weight_scale,
+        workspace=workspace,
+        num_bits=8,
+        size_m=reshaped_x.shape[0],
+        size_n=size_n,
+        size_k=size_k,
+    )
+
+    if bias is not None:
+        output.add_(bias)  # In-place add
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
+    print_warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads.")
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+
+    device = layer.weight.device
+
+    # WORKSPACE
+    layer.workspace = marlin_make_workspace(part_size_n, device)
+
+    # WEIGHT
+    # Repack weights to marlin format
+    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
+        layer.weight),
+                                            perm=torch.empty(0,
+                                                             dtype=torch.int,
+                                                             device=device),
+                                            size_k=part_size_k,
+                                            size_n=part_size_n,
+                                            num_bits=8)
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Currently Marlin doesn't support per-tensor scales, so we
+    # expand it to channelwise
+    scales = layer.weight_scale.repeat(1, part_size_n).to(
+        layer.orig_dtype).to(device)
+    # Permute scales
+    marlin_scales = marlin_permute_scales(s=scales,
+                                          size_k=part_size_k,
+                                          size_n=part_size_n,
+                                          group_size=-1)
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+
+def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.shape[0] % 4 == 0
+
+    # Reshape to prepare for packing
+    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+    # Convert fp8 to uint8 (byte) representation
+    byte_tensor = reshaped.view(torch.uint8)
+
+    # Pack 4 uint8 values into one int32
+    packed = (byte_tensor[:, 0].to(torch.int32) |
+              (byte_tensor[:, 1].to(torch.int32) << 8) |
+              (byte_tensor[:, 2].to(torch.int32) << 16) |
+              (byte_tensor[:, 3].to(torch.int32) << 24))
+
+    return packed.view(fp8_tensor.shape[0] // 4,
+                       *fp8_tensor.shape[1:]).contiguous()

+ 120 - 0
aphrodite/quantization/utils/marlin_utils_test.py

@@ -0,0 +1,120 @@
+"""Utility functions used for tests and benchmarks"""
+
+from typing import List
+
+import numpy
+import torch
+
+from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales
+from .quant_utils import get_pack_factor, quantize_weights, sort_weights
+
+
+class MarlinWorkspace:
+
+    def __init__(self, out_features, min_thread_n, max_parallel):
+        assert (out_features % min_thread_n == 0), (
+            "out_features = {} is undivisible by min_thread_n = {}".format(
+                out_features, min_thread_n))
+
+        max_workspace_size = ((out_features // min_thread_n) * max_parallel)
+
+        self.scratch = torch.zeros(max_workspace_size,
+                                   dtype=torch.int,
+                                   device="cuda")
+
+
+def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
+    assert q_w.shape == (size_k, size_n)
+    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
+
+    # Permute weights to 16x64 marlin tiles
+    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+    q_w = q_w.permute((0, 2, 1, 3))
+    q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+    return q_w
+
+
+def marlin_weights(q_w, size_k, size_n, num_bits, perm):
+    # Permute
+    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
+
+    # Pack
+    pack_factor = get_pack_factor(num_bits)
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
+                           dtype=numpy.uint32)
+    for i in range(pack_factor):
+        q_packed |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
+
+    return q_packed
+
+
+def get_weight_perm(num_bits: int):
+    perm_list: List[int] = []
+    for i in range(32):
+        perm1: List[int] = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                    2 * (i % 4),
+                    2 * (i % 4) + 1,
+                    2 * (i % 4 + 4),
+                    2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm_list.extend([p + 256 * j for p in perm1])
+
+    perm = numpy.array(perm_list)
+
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    return perm
+
+
+def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
+                    act_order: bool):
+    size_k, size_n = w.shape
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Quantize (and apply act_order if provided)
+    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
+                                                       act_order)
+
+    # For act_order, sort the "weights" and "g_idx" so that group ids are
+    # increasing
+    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
+
+    # Reformat to marlin
+    weight_perm = get_weight_perm(num_bits)
+    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
+    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
+
+    # Create result
+    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
+    for i in range(len(res_list)):
+        res_list[i] = res_list[i].to(w.device)
+
+    return res_list

+ 160 - 3
aphrodite/quantization/utils/format_24.py → aphrodite/quantization/utils/marlin_utils_test_24.py

@@ -1,9 +1,14 @@
-#
-# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
-#
+"""Utility functions used for tests and benchmarks"""
 
+import random
+from typing import List
+
+import numpy
 import torch
 
+from .marlin_utils_test import marlin_weights
+from .quant_utils import quantize_weights
+
 
 # This is PyTorch implementation of main part of reorder_meta()
 # function, from tools/util/include/cutlass/util/host_reorder.h file
@@ -306,3 +311,155 @@ def mask_creator(tensor):
     mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
 
     return mask
+
+
+def inject_24(w, size_k, size_n):
+    assert w.shape == (size_k, size_n)
+
+    mask = mask_creator(w.t()).t().cuda().bool()
+
+    return (mask * w).contiguous(), mask.contiguous()
+
+
+def check_24(w, num_rows_to_sample=50, _verbose=False):
+    BLOCK_SIZE = 4
+    MAX_NON_ZEROS = 2
+
+    w = w.t().contiguous()
+
+    print("check_24: w.shape = {}".format(w.shape))
+
+    num_rows, num_cols = w.shape
+    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
+    if _verbose:
+        print(f"Sampled row idxs = {sampled_row_idxs}")
+
+    total_segments = 0
+    non_24_segments = 0
+    for i in sampled_row_idxs:
+        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
+            total_segments += 1
+            block = w[i, j:j + BLOCK_SIZE]
+            num_nonzero = torch.count_nonzero(block)
+            if num_nonzero > MAX_NON_ZEROS:
+                print("i = {} j = {} block = {}".format(i, j, block))
+                non_24_segments += 1
+
+    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
+
+
+def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
+    assert q_24.shape == (size_k, size_n)
+
+    # Remove zp to normalize over 0
+    max_q_val = (1 << num_bits) - 1
+    zp = (max_q_val + 1) // 2
+    q_24_no_zp = q_24 - zp
+
+    # Compress
+    q_24_no_zp = q_24_no_zp.t().contiguous()
+    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
+        q_24_no_zp)
+    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
+
+    # Restore zp
+    q_24_comp = q_24_no_zp_comp + zp
+
+    # Resize meta to its actual shape (without moving any data)
+    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
+
+    return q_24_comp, meta
+
+
+def get_scale_perms_24():
+    scale_perm: List[int] = []
+    for i in range(8):
+        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
+    scale_perm_single: List[int] = []
+    for i in range(8):
+        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
+    return scale_perm, scale_perm_single
+
+
+def get_weight_perm_24(num_bits: int):
+    perm_list: List[int] = []
+    for i in range(32):
+        perm1: List[int] = []
+        col = i // 4
+        col_o = col // 2
+        for block in [0, 1]:
+            for row in [
+                    2 * (i % 4),
+                    2 * (i % 4) + 1,
+                    2 * (i % 4 + 4),
+                    2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
+                             4 * block)
+        for j in range(4):
+            perm_list.extend([p + 1 * j for p in perm1])
+    perm = numpy.array(perm_list)
+
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    return perm
+
+
+def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
+                             group_size: int) -> torch.Tensor:
+
+    scale_perm, scale_perm_single = get_scale_perms_24()
+    if group_size < size_k and group_size != -1:
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+    else:
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return s
+
+
+def marlin_24_quantize(
+    w: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+):
+    size_k, size_n = w.shape
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Inject 2:4 sparsity
+    w_24, mask_24 = inject_24(w, size_k, size_n)
+
+    # Quantize
+    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
+                                                             num_bits,
+                                                             group_size,
+                                                             act_order=False)
+
+    # Compress quantized weight
+    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
+                                                     num_bits)
+    size_k_comp = size_k // 2
+
+    # Reformat to marlin
+    weight_perm = get_weight_perm_24(num_bits)
+    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
+                                        num_bits, weight_perm)
+    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
+
+    # Create result
+    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
+    for i in range(len(res_list)):
+        res_list[i] = res_list[i].to(w.device)
+
+    return res_list