Ver Fonte

kernels: disambiguate quantized types via a new ScalarType

Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
AlpinDale há 6 meses atrás
pai
commit
141672a0d4

+ 7 - 1
.github/workflows/publish.yml

@@ -4,9 +4,13 @@
 name: Create Release
 
 on:
+  schedule:
+    - cron:  '0 2 * * *'
   push:
+    branches:
+      - 'rc_054'
     tags:
-      - v*
+      - 'v*'
 
 # Needed to create release and upload assets
 permissions:
@@ -55,6 +59,8 @@ jobs:
     steps:
       - name: Checkout
         uses: actions/checkout@v3
+        with:
+          ref: 'rc_054'
 
       - name: Set up Linux Env
         if: ${{ runner.os == 'Linux' }}

+ 35 - 17
CMakeLists.txt

@@ -66,6 +66,39 @@ endif()
 #
 find_package(Torch REQUIRED)
 
+#
+# Add the `default` target which detects which extensions should be
+# built based on platform/architecture.  This is the same logic that
+# setup.py uses to select which extensions should be built and should
+# be kept in sync.
+#
+# The `default` target makes direct use of cmake easier since knowledge
+# of which extensions are supported has been factored in, e.g.
+#
+# mkdir build && cd build
+# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
+# cmake --build . --target default
+#
+add_custom_target(default)
+message(STATUS "Enabling core extension.")
+
+# Define _core_C extension
+#  built for (almost) every target platform, (excludes TPU and Neuron)
+
+set(APHRODITE_EXT_SRC
+  "kernels/core/torch_bindings.cpp")
+
+define_gpu_extension_target(
+  _core_C
+  DESTINATION aphrodite
+  LANGUAGE CXX
+  SOURCES ${APHRODITE_EXT_SRC}
+  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
+  USE_SABI 3
+  WITH_SOABI)
+
+add_dependencies(default _core_C)
+
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -74,7 +107,7 @@ if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
     if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
         include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
     else()
-        message(FATAL_ERROR "Unsupported Aphrodite target device: ${APHRODITE_TARGET_DEVICE}")
+        return()
     endif()
     return()
 endif()
@@ -132,7 +165,7 @@ if(NVCC_THREADS AND APHRODITE_GPU_LANG STREQUAL "CUDA")
 endif()
 
 #
-# Define extension targets
+# Define other extension targets
 #
 
 #
@@ -227,21 +260,6 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
-#
-# Add the `default` target which detects which extensions should be
-# built based on platform/architecture.  This is the same logic that
-# setup.py uses to select which extensions should be built and should
-# be kept in sync.
-#
-# The `default` target makes direct use of cmake easier since knowledge
-# of which extensions are supported has been factored in, e.g.
-#
-# mkdir build && cd build
-# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
-# cmake --build . --target default
-#
-add_custom_target(default)
-
 if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
   add_dependencies(default _C)

+ 3 - 0
Dockerfile.openvino

@@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/aphrodite-engine/
 COPY requirements-openvino.txt /workspace/aphrodite-engine/
 
 COPY aphrodite/ /workspace/aphrodite-engine/aphrodite
+COPY kernels/core /workspace/aphrodite-engine/kernels/core
+COPY cmake/utils.cmake /workspace/aphrodite-engine/cmake/
+COPY CMakeLists.txt /workspace/aphrodite-engine/
 COPY setup.py /workspace/aphrodite-engine/
 
 # install build requirements

+ 176 - 0
aphrodite/_core_ext.py

@@ -0,0 +1,176 @@
+import importlib.util
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from loguru import logger
+
+core_C_available = importlib.util.find_spec('._core_C',
+                                            'aphrodite') is not None
+
+
+# Mirrors enum in `core/scalar_type.hpp`
+class NanRepr(Enum):
+    NONE = 0  # nans are not supported
+    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+if TYPE_CHECKING or not core_C_available:
+    # On platforms were we cannot use/build the C++ core extension (i.e. namely
+    # neuron and tpu), we define the mock ScalarType class here that partially
+    # mimics the C++ ScalarType class.
+    #
+    # We also use this provide type signatures to the Python LSP for the methods
+    # in the C++ ScalarType class. So these type signatures should be kept
+    # in sync with csrc/core/scalar_type.hpp
+
+    from dataclasses import dataclass
+
+    @dataclass(frozen=True)
+    class ScalarType:
+        """
+        ScalarType can represent a wide range of floating point and integer 
+        types, in particular it can be used to represent sub-byte data types 
+        (something that torch.dtype currently does not support). It is also 
+        capable of  representing types with a bias, i.e.:
+          `stored_value = value + bias`, 
+        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias 
+        of 8). The implementation for this class can be found in 
+        csrc/core/scalar_type.hpp, these type signatures should be kept in sync 
+        with that file.
+        """
+
+        exponent: int
+        """
+        Number of bits in the exponent if this is a floating point type
+        (zero if this an integer type)
+        """
+
+        mantissa: int
+        """
+        Number of bits in the mantissa if this is a floating point type,
+        or the number bits representing an integer excluding the sign bit if 
+        this an integer type.
+        """
+
+        bias: int
+        """
+        bias used to encode the values in this scalar type 
+        (value = stored_value - bias, default 0) for example if we store the 
+        type as an unsigned integer with a bias of 128 then the value 0 will be 
+        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+        """
+
+        signed: bool
+        "If the type is signed (i.e. has a sign bit)"
+
+        _finite_values_only: bool = False
+        """
+        Private: if NANs are supported, used `has_infs()` instead.
+        """
+
+        nan_repr: int = NanRepr.IEEE_754.value
+        """
+        How NaNs are represent in this scalar type, returns NanRepr value. 
+        (not applicable for integer types)
+        """
+
+        @property
+        def size_bits(self):
+            return self.exponent + self.mantissa + int(self.signed)
+
+        def min(self) -> Union[int, float]:
+            """
+            Min representable value for this scalar type. 
+            (accounting for bias if there is one)
+            """
+            raise NotImplementedError
+
+        def max(self) -> Union[int, float]:
+            """
+            Max representable value for this scalar type. 
+            (accounting for bias if there is one)
+            """
+            raise NotImplementedError
+
+        def is_signed(self) -> bool:
+            """
+            If the type is signed (i.e. has a sign bit), same as `signed`
+            added for consistency with:
+            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+            """
+            ...
+
+        def is_floating_point(self):
+            "If the type is a floating point type"
+            return self.exponent != 0
+
+        def is_integer(self):
+            "If the type is an integer type"
+            return self.exponent == 0
+
+        def has_bias(self):
+            "If the type has a non-zero bias"
+            return self.bias != 0
+
+        def has_infs(self):
+            "If the type is floating point and supports infinity"
+            return not self._finite_values_only
+
+        def has_nans(self):
+            return self.nan_repr != NanRepr.NONE.value
+
+        def is_ieee_754(self) -> bool:
+            """
+            If the type is a floating point type that follows IEEE 754 
+            conventions
+            """
+            return self.nan_repr == NanRepr.IEEE_754.value and \
+                not self._finite_values_only
+
+        def __str__(self) -> str:
+            raise NotImplementedError
+
+        def __repr__(self) -> str:
+            raise NotImplementedError
+
+        #
+        # Convenience Constructors
+        #
+
+        @classmethod
+        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            "Create a signed integer scalar type (size_bits includes sign-bit)."
+            return cls(size_bits - 1, size_bits, bias if bias else 0, True)
+
+        @classmethod
+        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            """Create a unsigned integer scalar type."""
+            return cls(size_bits, size_bits, bias if bias else 0, False)
+
+        @classmethod
+        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+            """
+            Create a standard floating point type 
+            (i.e. follows IEEE 754 conventions).
+            """
+            return cls(exponent, mantissa, 0, True)
+
+        @classmethod
+        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+                   nan_repr: int):
+            """
+            Create a non-standard floating point type 
+            (i.e. does not follow IEEE 754 conventions).
+            """
+            return cls(exponent, mantissa, 0, True, finite_values_only,
+                       nan_repr)
+
+elif core_C_available:
+    try:
+        import aphrodite._core_C  # noqa: F401
+    except ImportError as e:
+        logger.warning(f"Failed to import from aphrodite._core_C with {e}")
+
+    ScalarType = torch.classes._core_C.ScalarType

+ 20 - 10
aphrodite/_custom_ops.py

@@ -5,6 +5,8 @@ from typing import List, Optional, Tuple, Type
 import torch
 from loguru import logger
 
+from aphrodite._core_ext import ScalarType
+
 try:
     import aphrodite._C
 except ImportError as e:
@@ -217,10 +219,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 # marlin_24
 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                         b_meta: torch.Tensor, b_scales: torch.Tensor,
-                        workspace: torch.Tensor, num_bits: int, size_m: int,
-                        size_n: int, size_k: int) -> torch.Tensor:
+                        workspace: torch.Tensor, b_q_type: ScalarType,
+                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
-                                            workspace, num_bits, size_m,
+                                            workspace, b_q_type, size_m,
                                             size_n, size_k)
 
 
@@ -284,14 +286,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
     return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
 
 
-def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
-                     b_scales: torch.Tensor, b_zeros: torch.Tensor,
-                     g_idx: torch.Tensor, perm: torch.Tensor,
-                     workspace: torch.Tensor, num_bits: int, size_m: int,
-                     size_n: int, size_k: int, is_k_full: bool, has_zp: bool,
-                     use_fp32_reduce: bool) -> torch.Tensor:
+def gptq_marlin_gemm(a: torch.Tensor,
+                     b_q_weight: torch.Tensor,
+                     b_scales: torch.Tensor,
+                     b_zeros: torch.Tensor,
+                     g_idx: torch.Tensor,
+                     perm: torch.Tensor,
+                     workspace: torch.Tensor,
+                     b_q_type: ScalarType,
+                     size_m: int,
+                     size_n: int,
+                     size_k: int,
+                     is_k_full: bool,
+                     has_zp: bool = False,
+                     use_fp32_reduce: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
-                                         g_idx, perm, workspace, num_bits,
+                                         g_idx, perm, workspace, b_q_type,
                                          size_m, size_n, size_k, is_k_full,
                                          has_zp, use_fp32_reduce)
 

+ 31 - 18
aphrodite/quantization/awq_marlin.py

@@ -10,29 +10,40 @@ from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.utils.marlin_utils import (
-    apply_awq_marlin_linear, awq_to_marlin_zero_points,
-    check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
-    verify_marlin_supports_shape)
+    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
+from aphrodite.scalar_type import scalar_types
 
 
 class AWQMarlinConfig(QuantizationConfig):
     """Config class for AWQ Marlin"""
 
+    # num_bits -> type
+    TYPE_MAP = {
+        4: scalar_types.uint4,
+        8: scalar_types.uint8,
+    }
+
     def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
                  lm_head_quantized: bool) -> None:
-        self.weight_bits = weight_bits
-        self.pack_factor = 32 // self.weight_bits  # packed into int32
+        self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.has_zp = has_zp
         self.lm_head_quantized = lm_head_quantized
 
-        verify_awq_marlin_supported(num_bits=self.weight_bits,
-                                    group_size=self.group_size,
-                                    has_zp=self.has_zp)
+        if weight_bits not in self.TYPE_MAP:
+            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
+                             f"Supported num_bits = {self.TYPE_MAP.keys()}")
+
+        self.quant_type = self.TYPE_MAP[weight_bits]
+
+        verify_marlin_supported(self.quant_type,
+                                group_size=self.group_size,
+                                has_zp=self.has_zp)
 
     def __repr__(self) -> str:
-        return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
+        return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
                 f"has_zp={self.has_zp}, "
                 f"lm_head_quantized={self.lm_head_quantized})")
@@ -107,11 +118,13 @@ class AWQMarlinConfig(QuantizationConfig):
         if (num_bits is None or group_size is None or has_zp is None):
             return False
 
-        return check_awq_marlin_supported(
-            num_bits=num_bits,
-            group_size=group_size,
-            has_zp=has_zp,
-            min_capability=cls.get_min_capability())
+        if num_bits not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
+                                      group_size=group_size,
+                                      has_zp=has_zp,
+                                      min_capability=cls.get_min_capability())
 
 
 class AWQMarlinLinearMethod(LinearMethodBase):
@@ -222,7 +235,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qweight", marlin_qweight)
 
         # Permute scales from AWQ format to marlin format.
@@ -238,7 +251,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             layer.qzeros,
             size_k=layer.num_groups,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qzeros", marlin_zp)
 
         # Not-used
@@ -259,7 +272,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             g_idx=layer.g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
-            num_bits=self.quant_config.weight_bits,
+            quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             bias=bias)

+ 14 - 4
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

@@ -9,9 +9,13 @@ from aphrodite.quantization.compressed_tensors.schemes import \
     CompressedTensorsScheme
 from aphrodite.quantization.gptq_marlin_24 import (GPTQ_MARLIN_24_MAX_PARALLEL,
                                                    GPTQ_MARLIN_24_MIN_THREAD_N)
+from aphrodite.scalar_type import scalar_types
 
 __all__ = ["CompressedTensorsW4A16Sparse24"]
-W4A16SPARSE24_SUPPORTED_BITS = [4]
+W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.uint4b8,
+}
+W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
 
 
 class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
@@ -22,9 +26,15 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                  group_size: Optional[int] = None):
         self.strategy = strategy
         self.group_size = group_size
-        self.num_bits = num_bits
         self.tile_size = 16
 
+        if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}. "
+                f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
+
+        self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
+
         if self.strategy == "group" and self.group_size is None:
             raise ValueError(
                 "group_size must be given when using strategy group")
@@ -43,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
 
-        pack_factor = 32 // self.num_bits
+        pack_factor = 32 // self.quant_type.size_bits
         output_size_per_partition = sum(output_partition_sizes)
 
         qweight = Parameter(
@@ -137,7 +147,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
         size_n = scales.shape[1]
 
         output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
-                                            workspace, self.num_bits, size_m,
+                                            workspace, self.quant_type, size_m,
                                             size_n, size_k)
 
         output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

+ 19 - 9
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

@@ -9,11 +9,16 @@ from aphrodite.quantization.compressed_tensors.schemes import \
     CompressedTensorsScheme
 from aphrodite.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
+    marlin_permute_scales, replace_tensor, verify_marlin_supported,
     verify_marlin_supports_shape)
+from aphrodite.scalar_type import scalar_types
 
 __all__ = ["CompressedTensorsWNA16"]
-WNA16_SUPPORTED_BITS = [4, 8]
+WNA16_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.uint4b8,
+    8: scalar_types.uint8b128,
+}
+WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
 
 
 class CompressedTensorsWNA16(CompressedTensorsScheme):
@@ -22,8 +27,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                  strategy: str,
                  num_bits: int,
                  group_size: Optional[int] = None):
-        self.num_bits = num_bits
-        self.pack_factor = 32 // self.num_bits
+        self.pack_factor = 32 // num_bits
         self.strategy = strategy
 
         self.group_size: int
@@ -37,10 +41,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         else:
             self.group_size = group_size
 
+        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}. "
+                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
+
+        self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
+
         # Verify supported on platform.
-        verify_gptq_marlin_supported(num_bits=self.num_bits,
-                                     group_size=self.group_size,
-                                     is_sym=True)
+        verify_marlin_supported(quant_type=self.quant_type,
+                                group_size=self.group_size)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -150,7 +160,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             perm=layer.g_idx_sort_indices,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.num_bits)
+            num_bits=self.quant_type.size_bits)
         replace_tensor(layer, "weight_packed", marlin_qweight)
 
         # Permute scales from compressed-tensors format to marlin format.
@@ -172,7 +182,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             g_idx=layer.g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
-            num_bits=self.num_bits,
+            wtype=self.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             is_k_full=True,

+ 27 - 16
aphrodite/quantization/gptq_marlin.py

@@ -10,15 +10,22 @@ from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
+    apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
     marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
-    verify_gptq_marlin_supported, verify_marlin_supports_shape)
+    verify_marlin_supported, verify_marlin_supports_shape)
+from aphrodite.scalar_type import scalar_types
 
 
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
+
     def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                  is_sym: bool, lm_head_quantized: bool) -> None:
         if desc_act and group_size == -1:
@@ -26,20 +33,23 @@ class GPTQMarlinConfig(QuantizationConfig):
             # (since we have only one group per output channel)
             desc_act = False
 
-        self.weight_bits = weight_bits
-        self.pack_factor = 32 // self.weight_bits  # packed into int32
+        self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.desc_act = desc_act
-        self.is_sym = is_sym
         self.lm_head_quantized = lm_head_quantized
 
+        if (weight_bits, is_sym) not in self.TYPE_MAP:
+            raise ValueError("Unsupported quantization config: "
+                             f"bits={weight_bits}, sym={is_sym}")
+
+        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
+
         # Verify supported on platform.
-        verify_gptq_marlin_supported(num_bits=self.weight_bits,
-                                     group_size=self.group_size,
-                                     is_sym=self.is_sym)
+        verify_marlin_supported(quant_type=self.quant_type,
+                                group_size=self.group_size)
 
     def __repr__(self) -> str:
-        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
+        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
                 f"desc_act={self.desc_act}, "
                 f"lm_head_quantized={self.lm_head_quantized})")
@@ -119,11 +129,12 @@ class GPTQMarlinConfig(QuantizationConfig):
                 or desc_act is None):
             return False
 
-        return check_gptq_marlin_supported(
-            num_bits=num_bits,
-            group_size=group_size,
-            is_sym=sym,
-            min_capability=cls.get_min_capability())
+        if (num_bits, sym) not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
+                                      group_size=group_size,
+                                      min_capability=cls.get_min_capability())
 
 
 class GPTQMarlinLinearMethod(LinearMethodBase):
@@ -290,7 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             perm=layer.g_idx_sort_indices,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.quant_type.size_bits)
         replace_tensor(layer, "qweight", marlin_qweight)
 
         # Permute scales from autogptq format to marlin format.
@@ -316,7 +327,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             g_idx=layer.g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
-            num_bits=self.quant_config.weight_bits,
+            wtype=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             is_k_full=layer.is_k_full,

+ 19 - 10
aphrodite/quantization/gptq_marlin_24.py

@@ -8,15 +8,17 @@ from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.scalar_type import scalar_types
 
 GPTQ_MARLIN_24_TILE = 16
 GPTQ_MARLIN_24_MIN_THREAD_N = 128
 GPTQ_MARLIN_24_MIN_THREAD_K = 128
 GPTQ_MARLIN_24_MAX_PARALLEL = 64
 
-GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
+    scalar_types.uint4b8, scalar_types.uint8b128
+]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
-GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
 
 
 class GPTQMarlin24Config(QuantizationConfig):
@@ -28,14 +30,19 @@ class GPTQMarlin24Config(QuantizationConfig):
         weight_bits: int,
         group_size: int,
     ) -> None:
-        self.weight_bits = weight_bits
+        quant_type = {
+            4: scalar_types.uint4b8,
+            8: scalar_types.uint8b128,
+        }.get(weight_bits)
+
         self.group_size = group_size
 
         # Verify
-        if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
+        if quant_type is None or \
+            quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
             raise ValueError(
-                f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
-                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
+                f"Marlin_24 does not support quant_type = {quant_type}. "
+                f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
                 "are supported.")
         if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             raise ValueError(
@@ -43,8 +50,10 @@ class GPTQMarlin24Config(QuantizationConfig):
                 f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                 "are supported.")
 
+        self.quant_type = quant_type
+
         # 4 Bits packed into 32 bit datatype.
-        self.pack_factor = 32 // self.weight_bits
+        self.pack_factor = 32 // self.quant_type.size_bits
 
         # Tile size used by marlin kernels.
         self.tile_size = 16
@@ -63,8 +72,8 @@ class GPTQMarlin24Config(QuantizationConfig):
         self.perm_len = 1024
 
     def __repr__(self) -> str:
-        return "Marlin24Config(weight_bits={}, group_size={})".format(
-            self.weight_bits, self.group_size)
+        return "Marlin24Config(quant_type={}, group_size={})".format(
+            self.quant_type, self.group_size)
 
     @classmethod
     def get_name(cls) -> str:
@@ -276,7 +285,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
 
         output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                             workspace,
-                                            self.quant_config.weight_bits,
+                                            self.quant_config.quant_type,
                                             size_m, size_n, size_k)
 
         output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

+ 57 - 63
aphrodite/quantization/utils/marlin_utils.py

@@ -5,6 +5,7 @@ import torch
 
 from aphrodite import _custom_ops as ops
 from aphrodite.platforms import current_platform
+from aphrodite.scalar_type import ScalarType, scalar_types
 
 from .quant_utils import pack_cols, unpack_cols
 
@@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
 GPTQ_MARLIN_MAX_PARALLEL = 16
 
-MARLIN_SUPPORTED_NUM_BITS = [4, 8]
 MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
 # In case there is a performance issue with Marlin, the variable below can be
@@ -22,76 +22,70 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 USE_FP32_REDUCE_DEFAULT = True
 
 
-def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
-                            min_capability: Optional[int],
-                            has_zp: bool) -> Tuple[bool, Optional[str]]:
-    if min_capability is not None:
+# For binary size and compile time, we don't support the same types for with and
+#  without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
+#  TODO: we may want to move this into the C++ so its closer to the actual impl
+def query_marlin_supported_quant_types(has_zp: bool,
+                                       min_capability: Optional[int] = None):
+    if min_capability is None:
         major, minor = current_platform.get_device_capability()
-        device_capability = major * 10 + minor
-        if device_capability < min_capability:
-            return (False, "Marlin does not support device_capability = {}"
-                    ", the min_capability required is {}".format(
-                        device_capability, min_capability))
-
-    if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
-        return (False, "Marlin does not support weight_bits = {}. "
-                "Only weight_bits = {} are supported.".format(
-                    num_bits, MARLIN_SUPPORTED_NUM_BITS))
-
-    if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
-        return (False, "Marlin does not support group_size = {}. Only "
-                "group_sizes = {} are supported.".format(
-                    group_size, MARLIN_SUPPORTED_GROUP_SIZES))
-
-    if not has_zp and not is_sym:
-        return (False,
-                "Marlin without zero_points must have symmetric quantization")
+        min_capability = major * 10 + minor
 
-    return True, None
+    if min_capability < 80:
+        return []
 
+    if has_zp:
+        # AWQ style, unsigned + runtime zero-point
+        return [scalar_types.uint4, scalar_types.uint8]
+    else:
+        # GPTQ style, unsigned + symmetric bias
+        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
+        #  to add `scalar_types.float8_e4m3fn` here
+        return [scalar_types.uint4b8, scalar_types.uint8b128]
 
-def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
-                                min_capability: int) -> bool:
-    cond, _ = _check_marlin_supported(num_bits,
-                                      group_size,
-                                      is_sym,
-                                      min_capability,
-                                      has_zp=False)
-    return cond
 
+def _check_marlin_supported(
+        quant_type: ScalarType,
+        group_size: Optional[int],
+        has_zp: bool,
+        min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
 
-def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
-                               min_capability: int) -> bool:
-    cond, _ = _check_marlin_supported(num_bits,
-                                      group_size,
-                                      False,
-                                      min_capability,
-                                      has_zp=has_zp)
-    return cond
+    if min_capability is None:
+        major, minor = current_platform.get_device_capability()
+        min_capability = major * 10 + minor
 
+    supported_types = query_marlin_supported_quant_types(
+        has_zp, min_capability)
 
-def verify_gptq_marlin_supported(num_bits: int, group_size: int,
-                                 is_sym: bool) -> None:
-    cond, err_msg = _check_marlin_supported(num_bits,
-                                            group_size,
-                                            is_sym,
-                                            min_capability=None,
-                                            has_zp=False)
-    if not cond:
-        assert err_msg is not None
-        raise ValueError("GPTQ" + err_msg)
+    if quant_type not in supported_types:
+        return (False, f"Marlin does not support weight_bits = {quant_type}. "
+                f"Only types = {supported_types} "
+                f"are supported (for group_size = {group_size}, "
+                f"min_capability = {min_capability}, zp = {has_zp}).")
+    if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
+        return (False, f"Marlin does not support group_size = {group_size}. "
+                f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
+                "are supported.")
+
+    return True, None
+
+
+def check_marlin_supported(quant_type: ScalarType,
+                           group_size: int,
+                           has_zp: bool = False,
+                           min_capability: Optional[int] = None) -> bool:
+    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
+                                      min_capability)
+    return cond
 
 
-def verify_awq_marlin_supported(num_bits: int, group_size: int,
-                                has_zp: bool) -> None:
-    cond, err_msg = _check_marlin_supported(num_bits,
-                                            group_size,
-                                            False,
-                                            min_capability=None,
-                                            has_zp=has_zp)
+def verify_marlin_supported(quant_type: ScalarType,
+                            group_size: int,
+                            has_zp: bool = False) -> None:
+    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
     if not cond:
         assert err_msg is not None
-        raise ValueError("AWQ" + err_msg)
+        raise ValueError(err_msg)
 
 
 def verify_marlin_supports_shape(output_size_per_partition: int,
@@ -245,7 +239,7 @@ def apply_gptq_marlin_linear(
         g_idx: torch.Tensor,
         g_idx_sort_indices: torch.Tensor,
         workspace: torch.Tensor,
-        num_bits: int,
+        wtype: ScalarType,
         output_size_per_partition: int,
         input_size_per_partition: int,
         is_k_full: bool,
@@ -261,7 +255,7 @@ def apply_gptq_marlin_linear(
                                   g_idx,
                                   g_idx_sort_indices,
                                   workspace,
-                                  num_bits,
+                                  wtype,
                                   size_m=reshaped_x.shape[0],
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,
@@ -283,7 +277,7 @@ def apply_awq_marlin_linear(
         g_idx: torch.Tensor,
         g_idx_sort_indices: torch.Tensor,
         workspace: torch.Tensor,
-        num_bits: int,
+        quant_type: ScalarType,
         output_size_per_partition: int,
         input_size_per_partition: int,
         bias: Optional[torch.Tensor] = None,
@@ -298,7 +292,7 @@ def apply_awq_marlin_linear(
                                   g_idx,
                                   g_idx_sort_indices,
                                   workspace,
-                                  num_bits,
+                                  quant_type,
                                   size_m=reshaped_x.shape[0],
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,

+ 19 - 10
aphrodite/quantization/utils/marlin_utils_test.py

@@ -5,10 +5,12 @@ from typing import List
 import numpy as np
 import torch
 
+from aphrodite.scalar_type import ScalarType
+
 from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales,
                            marlin_zero_points)
-from .quant_utils import (get_pack_factor, quantize_weights,
-                          quantize_weights_with_zp, sort_weights)
+from .quant_utils import (get_pack_factor, gptq_quantize_weights,
+                          quantize_weights, sort_weights)
 
 
 class MarlinWorkspace:
@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
     return perm
 
 
-def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
+def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
                     act_order: bool):
     size_k, size_n = w.shape
+    num_bits = quant_type.size_bits
 
     # Normalize group_size
     if group_size == -1:
@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
     assert group_size <= size_k
 
     # Quantize (and apply act_order if provided)
-    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
-                                                       act_order)
+    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
+        w, quant_type, group_size, act_order)
 
     # For act_order, sort the "weights" and "g_idx" so that group ids are
     # increasing
@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
     return res_list
 
 
-def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
+def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType,
+                        group_size: int):
     size_k, size_n = w.shape
 
     # Normalize group_size
@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
     num_groups = size_k // group_size
 
     # Quantize with zp
-    w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size)
+    w_ref, q_w, s, zp = quantize_weights(w,
+                                         quant_type,
+                                         group_size,
+                                         zero_points=True)
 
     # Reformat to marlin
-    weight_perm = get_weight_perm(num_bits)
-    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
+    weight_perm = get_weight_perm(quant_type.size_bits)
+    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
+                                weight_perm)
     marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
-    marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits)
+    marlin_zp = marlin_zero_points(zp, num_groups, size_n,
+                                   quant_type.size_bits)
 
     # Create result
     res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]

+ 14 - 16
aphrodite/quantization/utils/marlin_utils_test_24.py

@@ -6,8 +6,10 @@ from typing import List
 import numpy
 import torch
 
+from aphrodite.scalar_type import ScalarType
+
 from .marlin_utils_test import marlin_weights
-from .quant_utils import quantize_weights
+from .quant_utils import gptq_quantize_weights
 
 
 # This is PyTorch implementation of main part of reorder_meta()
@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
     print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
 
 
-def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
+def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
     assert q_24.shape == (size_k, size_n)
 
-    # Remove zp to normalize over 0
-    max_q_val = (1 << num_bits) - 1
-    zp = (max_q_val + 1) // 2
-    q_24_no_zp = q_24 - zp
+    # Remove bias to normalize over 0
+    q_24_no_zp = q_24 - wtype.bias
 
     # Compress
     q_24_no_zp = q_24_no_zp.t().contiguous()
@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
         q_24_no_zp)
     q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
 
-    # Restore zp
-    q_24_comp = q_24_no_zp_comp + zp
+    # Restore bias
+    q_24_comp = q_24_no_zp_comp + wtype.bias
 
     # Resize meta to its actual shape (without moving any data)
     meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
 
 def marlin_24_quantize(
     w: torch.Tensor,
-    num_bits: int,
+    quant_type: ScalarType,
     group_size: int,
 ):
     size_k, size_n = w.shape
@@ -441,20 +441,18 @@ def marlin_24_quantize(
     w_24, mask_24 = inject_24(w, size_k, size_n)
 
     # Quantize
-    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
-                                                             num_bits,
-                                                             group_size,
-                                                             act_order=False)
+    w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
+        w_24, quant_type, group_size, act_order=False)
 
     # Compress quantized weight
     q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
-                                                     num_bits)
+                                                     quant_type)
     size_k_comp = size_k // 2
 
     # Reformat to marlin
-    weight_perm = get_weight_perm_24(num_bits)
+    weight_perm = get_weight_perm_24(quant_type.size_bits)
     marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
-                                        num_bits, weight_perm)
+                                        quant_type.size_bits, weight_perm)
     marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
 
     # Create result

+ 63 - 84
aphrodite/quantization/utils/quant_utils.py

@@ -4,7 +4,10 @@ from typing import List
 import numpy
 import torch
 
-SUPPORTED_NUM_BITS = [4, 8]
+from aphrodite.quantization.qqq import MARLIN_QQQ_SUPPORTED_NUM_BITS
+from aphrodite.scalar_type import ScalarType, scalar_types
+
+SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
 # NOTE: this is a hack. We should update each model to register the
@@ -45,7 +48,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
 
 
 def get_pack_factor(num_bits):
-    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
     return 32 // num_bits
 
 
@@ -74,24 +77,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
     )
 
 
-def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
-                     act_order: bool):
+def quantize_weights(w: torch.Tensor,
+                     quant_type: ScalarType,
+                     group_size: int,
+                     zero_points: bool = False):
+    assert quant_type.is_integer(), \
+        "Floating point quantization may work but has not been tested"
+
     orig_device = w.device
+    orig_type = w.dtype
     size_k, size_n = w.shape
 
     assert w.is_floating_point(), "w must be float"
-    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
-    assert group_size in SUPPORTED_GROUP_SIZES + [
-        size_k
-    ], f"Unsupported groupsize = {group_size}"
 
     if group_size == -1:
         group_size = size_k
     assert group_size <= size_k
 
-    max_q_val = 2**num_bits - 1
-    half_q_val = (max_q_val + 1) // 2
-
     # Reshape to [groupsize, -1]
     if group_size < size_k:
         w = w.reshape((-1, group_size, size_n))
@@ -99,16 +101,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
         w = w.reshape((group_size, -1))
 
     # Compute scale for each group
-    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
-    s *= 2 / max_q_val  # 2 => symmetric
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    if zero_points:
+        assert not quant_type.is_signed() and quant_type.max() > 0
+        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
+            .clamp(min_q_val, max_q_val).int()
+    else:
+        # If the bias is such that there are no possible negative/positive
+        #  values, set the max value to inf to avoid divide by 0
+        w_s = torch.max(
+            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
+        maybe_w_zp = None
 
     # Quantize
-    q_w = torch.round(w / s).int()
-    q_w += half_q_val
-    q_w = torch.clamp(q_w, 0, max_q_val)
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
 
     # Compute ref (dequantized)
-    w_ref = (q_w - half_q_val).half() * s
+    w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
 
     # Restore original shapes
     if group_size < size_k:
@@ -119,90 +139,48 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
             w = w.reshape((size_k, size_n)).contiguous()
             return w
 
-        q_w = reshape_w(q_w)
+        w_q = reshape_w(w_q)
         w_ref = reshape_w(w_ref)
 
-    s = s.reshape((-1, size_n)).contiguous()
+    w_s = w_s.reshape((-1, size_n)).contiguous()
 
-    # Apply act_order
-    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
-    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
-    if act_order:
-        assert (
-            group_size < size_k
-        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
-            group_size, size_k)
-
-        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
+    if zero_points:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
 
     return (
         w_ref.to(device=orig_device),
-        q_w.to(device=orig_device),
-        s.to(device=orig_device),
-        g_idx.to(device=orig_device),
-        rand_perm.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s.to(device=orig_device),
+        maybe_w_zp,
     )
 
 
-def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
-    orig_device = w.device
-    size_k, size_n = w.shape
+def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
+                          group_size: int, act_order: bool):
+    size_k, _ = w.shape
 
     assert w.is_floating_point(), "w must be float"
-    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
+    assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
+        f"Unsupported gptq type = {quant_type}"
     assert group_size in SUPPORTED_GROUP_SIZES + [
         size_k
     ], f"Unsupported groupsize = {group_size}"
 
-    if group_size == -1:
-        group_size = size_k
-    assert group_size <= size_k
-
-    max_q_val = 2**num_bits - 1
-    min_q_val = 0
+    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
 
-    # Reshape to [groupsize, -1]
-    if group_size < size_k:
-        w = w.reshape((-1, group_size, size_n))
-        w = w.permute(1, 0, 2)
-        w = w.reshape((group_size, -1))
-
-    # Compute scale for each group
-    max = torch.max(w, 0, keepdim=True)[0]
-    min = torch.min(w, 0, keepdim=True)[0]
-    s = (max - min).clamp(min=1e-5) / max_q_val
-
-    # Compute zero-point for each group
-    zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int()
-
-    # Quantize
-    q_w = torch.round(w / s).int() + zp
-    q_w = torch.clamp(q_w, min_q_val, max_q_val)
-
-    # Compute ref (dequantized)
-    w_ref = (q_w - zp).half() * s
-
-    # Restore original shapes
-    if group_size < size_k:
-
-        def reshape_w(w):
-            w = w.reshape((group_size, -1, size_n))
-            w = w.permute(1, 0, 2)
-            w = w.reshape((size_k, size_n)).contiguous()
-            return w
-
-        q_w = reshape_w(q_w)
-        w_ref = reshape_w(w_ref)
+    # Apply act_order
+    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
+    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        assert (
+            group_size < size_k
+        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
+            group_size, size_k)
 
-    s = s.reshape((-1, size_n)).contiguous()
-    zp = zp.reshape((-1, size_n)).contiguous()
+        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
 
-    return (
-        w_ref.to(device=orig_device),
-        q_w.to(device=orig_device),
-        s.to(device=orig_device),
-        zp.to(device=orig_device),
-    )
+    return w_ref, w_q, w_s, g_idx, rand_perm
 
 
 # QQQ employs different quant schemes for per-group and
@@ -212,7 +190,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
     size_k, size_n = w.shape
 
     assert w.is_floating_point(), "w must be float"
-    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
+    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
+           f"Unsupported num_bits = {num_bits}"
     assert group_size in SUPPORTED_GROUP_SIZES + [
         size_k
     ], f"Unsupported groupsize = {group_size}"

+ 35 - 0
aphrodite/scalar_type.py

@@ -0,0 +1,35 @@
+from ._core_ext import NanRepr, ScalarType
+
+# naming generally follows: https://github.com/jax-ml/ml_dtypes
+# for floating point types (leading f) the scheme is:
+#  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+#  flags:
+#  - no-flags: means it follows IEEE 754 conventions
+#  - f: means finite values only (no infinities)
+#  - n: means nans are supported (non-standard encoding)
+# for integer types the scheme is:
+#  `[u]int<size_bits>[b<bias>]`
+#  - if bias is not present it means its zero
+
+
+class scalar_types:
+    int4 = ScalarType.int_(4, None)
+    uint4 = ScalarType.uint(4, None)
+    int8 = ScalarType.int_(8, None)
+    uint8 = ScalarType.uint(8, None)
+    float8_e4m3fn = ScalarType.float_(4, 3, True,
+                                      NanRepr.EXTD_RANGE_MAX_MIN.value)
+    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
+    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
+    float16_e5m10 = ScalarType.float_IEEE754(5, 10)
+
+    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
+    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
+
+    # "gptq" types
+    uint4b8 = ScalarType.uint(4, 8)
+    uint8b128 = ScalarType.uint(8, 128)
+
+    # colloquial names
+    bfloat16 = float16_e8m7
+    float16 = float16_e5m10

+ 0 - 1
aphrodite/server/launch.py

@@ -8,7 +8,6 @@ from loguru import logger
 
 
 async def serve_http(app: FastAPI, **uvicorn_kwargs: Any) -> None:
-    logger.info("Available routes are:")
     for route in app.routes:
         methods = getattr(route, "methods", None)
         path = getattr(route, "path", None)

+ 0 - 1
cmake/cpu_extension.cmake

@@ -113,6 +113,5 @@ define_gpu_extension_target(
     WITH_SOABI
 )
 
-add_custom_target(default)
 message(STATUS "Enabling C extension.")
 add_dependencies(default _C)

+ 0 - 0
kernels/registration.h → kernels/core/registration.h


+ 382 - 0
kernels/core/scalar_type.hpp

@@ -0,0 +1,382 @@
+#pragma once
+
+#include <torch/custom_class.h>
+
+namespace aphrodite {
+
+//
+//  ScalarType can represent a wide range of floating point and integer types,
+//  in particular it can be used to represent sub-byte data types (something
+//  that torch.dtype currently does not support).
+//
+//  ScalarTypeTorch is a subclass of ScalarType that is compatible with
+//  TORCH_LIBRARY, making it accessible from Python as well meaning this class
+//  can be used as a argument for custom operators, helping to simplify these
+//  interfaces.
+//
+//  The type definitions on the Python side can be found in: aphrodite/_core_ext.pyi
+//  these type definitions should be kept up to date with any Python API changes
+//  here.
+//
+class ScalarType {
+ public:
+  enum NanRepr : int64_t {
+    NAN_NONE = 0,                // nans are not supported
+    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
+    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s
+
+    NAN_REPR_ID_MAX
+  };
+
+  constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
+                       int64_t bias, bool finite_values_only = false,
+                       NanRepr nan_repr = NAN_IEEE_754)
+      : exponent(exponent),
+        mantissa(mantissa),
+        bias(bias),
+        signed_(signed_),
+        finite_values_only(finite_values_only),
+        nan_repr(nan_repr){};
+
+  static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
+    return ScalarType(true, 0, size_bits - 1, bias);
+  }
+
+  static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
+    return ScalarType(false, 0, size_bits, bias);
+  }
+
+  // IEEE 754 compliant floating point type
+  static constexpr ScalarType float_IEEE754(int64_t exponent,
+                                            int64_t mantissa) {
+    TORCH_CHECK(mantissa > 0 && exponent > 0);
+    return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
+  }
+
+  // IEEE 754 non-compliant floating point type
+  static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
+                                     bool finite_values_only,
+                                     NanRepr nan_repr) {
+    TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
+    TORCH_CHECK(mantissa > 0 && exponent > 0);
+    TORCH_CHECK(nan_repr != NAN_IEEE_754,
+                "use `float_IEEE754` constructor for floating point types that "
+                "follow IEEE 754 conventions");
+    return ScalarType(true, exponent, mantissa, 0, finite_values_only,
+                      nan_repr);
+  }
+
+  int64_t const exponent;  // size of the exponent field (0 for integer types)
+  int64_t const mantissa;  // size of the mantissa field (size of the integer
+                           // excluding the sign bit for integer types)
+  int64_t const bias;      // stored values equal value + bias,
+                           // used for quantized type
+  bool const signed_;  // flag if the type supports negative numbers (i.e. has a
+                       // sign bit)
+
+  // Extra Floating point info
+  bool const finite_values_only;  // i.e. no +/-inf if true
+  NanRepr const nan_repr;         // how NaNs are represented
+                                  // (not applicable for integer types)
+
+  int64_t size_bits() const { return mantissa + exponent + is_signed(); }
+  bool is_signed() const { return signed_; }
+  bool is_integer() const { return exponent == 0; }
+  bool is_floating_point() const { return exponent > 0; }
+  bool is_ieee_754() const {
+    return is_floating_point() && finite_values_only == false &&
+           nan_repr == NAN_IEEE_754;
+  }
+  bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
+  bool has_infs() const {
+    return is_floating_point() && finite_values_only == false;
+  }
+  bool has_bias() const { return bias != 0; }
+
+ private:
+  double _floating_point_max() const {
+    TORCH_CHECK(mantissa <= 52 && exponent <= 11,
+                "Cannot represent max/min as a double for type ", str());
+
+    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
+    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
+      max_mantissa -= 1;
+    }
+
+    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
+    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
+      TORCH_CHECK(exponent < 11,
+                  "Cannot represent max/min as a double for type ", str());
+      max_exponent += 1;
+    }
+
+    // adjust the exponent to match that of a double
+    //  for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
+    //  is the exponent bits), there is some precedent for non-standard biases,
+    //  example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
+    //  but to avoid premature over complication we are just assuming the
+    //  standard exponent bias until there is a need to support non-standard
+    //  biases
+    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
+    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11
+
+    uint64_t max_exponent_double =
+        max_exponent - exponent_bias + exponent_bias_double;
+
+    // shift the mantissa into the position for a double and
+    // the exponent
+    uint64_t double_raw =
+        (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
+
+    return *reinterpret_cast<double*>(&double_raw);
+  }
+
+  std::variant<int64_t, double> _raw_max() const {
+    if (is_floating_point()) {
+      return {_floating_point_max()};
+    } else {
+      TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
+                  "Cannot represent max as a int64_t");
+      return {(int64_t(1) << mantissa) - 1};
+    }
+  }
+
+  std::variant<int64_t, double> _raw_min() const {
+    if (is_floating_point()) {
+      TORCH_CHECK(is_signed(),
+                  "We currently assume all floating point types are signed");
+      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
+
+      double max = _floating_point_max();
+      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
+      uint64_t min_raw = max_raw | sign_bit_double;
+      return {*reinterpret_cast<double*>(&min_raw)};
+    } else {
+      TORCH_CHECK(!is_signed() || size_bits() <= 64,
+                  "Cannot represent min as a int64_t");
+      if (is_signed()) {
+        // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
+        // then perform an arithmetic shift right to set all the bits above
+        // (size_bits() - 1) to 1
+        return {INT64_MIN >> (64 - size_bits())};
+      } else {
+        return {int64_t(0)};
+      }
+    }
+  }
+
+ public:
+  // Max representable value for this scalar type.
+  // (accounting for bias if there is one)
+  std::variant<int64_t, double> max() const {
+    return std::visit(
+        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
+        _raw_max());
+  }
+
+  // Min representable value for this scalar type.
+  // (accounting for bias if there is one)
+  std::variant<int64_t, double> min() const {
+    return std::visit(
+        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
+        _raw_min());
+  }
+
+  std::string str() const {
+    /* naming generally follows: https://github.com/jax-ml/ml_dtypes
+     * for floating point types (leading f) the scheme is:
+     *  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+     *  flags:
+     *  - no-flags: means it follows IEEE 754 conventions
+     *  - f: means finite values only (no infinities)
+     *  - n: means nans are supported (non-standard encoding)
+     * for integer types the scheme is:
+     *  `[u]int<size_bits>[b<bias>]`
+     *  - if bias is not present it means its zero
+     */
+    if (is_floating_point()) {
+      auto ret = "float" + std::to_string(size_bits()) + "_e" +
+                 std::to_string(exponent) + "m" + std::to_string(mantissa);
+      if (!is_ieee_754()) {
+        if (finite_values_only) {
+          ret += "f";
+        }
+        if (nan_repr != NAN_NONE) {
+          ret += "n";
+        }
+      }
+      return ret;
+    } else {
+      auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
+      if (has_bias()) {
+        ret += "b" + std::to_string(bias);
+      }
+      return ret;
+    }
+  }
+
+  bool operator==(ScalarType const& other) const {
+    return mantissa == other.mantissa && exponent == other.exponent &&
+           bias == other.bias && signed_ == other.signed_ &&
+           finite_values_only == other.finite_values_only &&
+           nan_repr == other.nan_repr;
+  }
+};
+
+// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
+//  torch::CustomClassHolder), we use multiple inheritance here since we cannot
+//  have ScalarType inherit from torch::CustomClassHolder and have a constexpr
+//  constructor at the same time (torch::CustomClassHolder does not have a
+//  constexpr destructor)
+class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
+ public:
+  ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
+                  bool _signed)
+      : ScalarType(exponent, mantissa, bias, _signed){};
+
+  ScalarTypeTorch(ScalarType type) : ScalarType(type){};
+
+  using Base = ScalarType;
+  using Self = ScalarTypeTorch;
+  using SelfPtr = c10::intrusive_ptr<Self>;
+
+  static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
+    return c10::make_intrusive<Self>(
+        ScalarType::int_(size_bits, bias.value_or(0)));
+  }
+
+  static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
+    return c10::make_intrusive<Self>(
+        ScalarType::uint(size_bits, bias.value_or(0)));
+  }
+
+  static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
+    return c10::make_intrusive<Self>(
+        ScalarType::float_IEEE754(exponent, mantissa));
+  }
+
+  static SelfPtr float_(int64_t exponent, int64_t mantissa,
+                        bool finite_values_only, int64_t nan_repr) {
+    return c10::make_intrusive<Self>(ScalarType::float_(
+        exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
+  }
+
+  template <typename T>
+  static void bind_readonly_property(torch::class_<Self>& cls,
+                                     std::string const& name, T Base::*field) {
+    auto getter_func = [field = std::move(field)](SelfPtr const& self) {
+      if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
+        return (self.get()->*field)();
+      } else {
+        return self.get()->*field;
+      }
+    };
+
+    cls.def_property(name, getter_func);
+  }
+
+  template <typename MemberFunc, typename Cls>
+  static void bind_function(torch::class_<Self>& cls, const std::string& name,
+                            MemberFunc Cls::*member) {
+    cls.def(name, [member = std::move(member)](SelfPtr const& self) {
+      return (self.get()->*member)();
+    });
+  }
+
+  template <typename Func>
+  static void bind_function(torch::class_<Self>& cls, const std::string& name,
+                            Func func) {
+    cls.def(name, func);
+  }
+
+  template <typename Func>
+  static void bind_static_function(torch::class_<Self>& cls,
+                                   const std::string& name, Func func) {
+    cls.def_static(name, func);
+  }
+
+  static void bind_class(torch::Library& lib) {
+    auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
+                   .def(torch::init<int64_t, int64_t, int64_t, bool>());
+
+    // Bind Properties
+    bind_readonly_property(cls, "mantissa", &Base::mantissa);
+    bind_readonly_property(cls, "exponent", &Base::exponent);
+    bind_readonly_property(cls, "bias", &Base::bias);
+    bind_readonly_property(cls, "signed", &Base::is_signed);
+    bind_readonly_property(cls, "size_bits", &Base::size_bits);
+
+    // Bind member functions
+    bind_function(cls, "is_signed", &Base::is_signed);
+    bind_function(cls, "is_integer", &Base::is_integer);
+    bind_function(cls, "is_floating_point", &Base::is_floating_point);
+    bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
+    bind_function(cls, "has_nans", &Base::has_nans);
+    bind_function(cls, "has_infs", &Base::has_infs);
+    bind_function(cls, "has_bias", &Base::has_bias);
+
+    bind_function(cls, "max", [](SelfPtr const& self) {
+      return std::visit([](auto arg) { return c10::IValue(arg); },
+                        self.get()->max());
+    });
+    bind_function(cls, "min", [](SelfPtr const& self) {
+      return std::visit([](auto arg) { return c10::IValue(arg); },
+                        self.get()->min());
+    });
+
+    bind_function(cls, "__str__", &Base::str);
+    bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
+      return *self == *other;
+    });
+    bind_function(cls, "__repr__", [](SelfPtr const& self) {
+      return "ScalarType." + self.get()->str();
+    });
+
+    // Bind static functions (convenience constructors)
+    bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
+    bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
+    bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
+    bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
+  }
+};
+
+using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
+
+// "rust style" names generally following:
+//   https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
+static inline constexpr auto kS4 = ScalarType::int_(4);
+static inline constexpr auto kU4 = ScalarType::uint(4);
+static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
+static inline constexpr auto kS8 = ScalarType::int_(8);
+static inline constexpr auto kU8 = ScalarType::uint(8);
+static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
+
+static inline constexpr auto kFE3M2f =
+    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
+static inline constexpr auto kFE4M3fn =
+    ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
+static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
+static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
+static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
+
+// Fixed width style names, generally following:
+//  https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
+static inline constexpr auto kInt4 = kS4;
+static inline constexpr auto kUint4 = kU4;
+static inline constexpr auto kUint4b8 = kU4B8;
+static inline constexpr auto kInt8 = kS8;
+static inline constexpr auto kUint8 = kU8;
+static inline constexpr auto kUint8b128 = kU8B128;
+
+static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
+static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
+static inline constexpr auto kFloat8_e5m2 = kFE5M2;
+static inline constexpr auto kFloat16_e8m7 = kFE8M7;
+static inline constexpr auto kFloat16_e5m10 = kFE5M10;
+
+// colloquial names
+static inline constexpr auto kHalf = kFE5M10;
+static inline constexpr auto kFloat16 = kHalf;
+static inline constexpr auto kBFloat16 = kFE8M7;
+
+};  // namespace aphrodite

+ 16 - 0
kernels/core/torch_bindings.cpp

@@ -0,0 +1,16 @@
+#include <torch/library.h>
+
+#include "scalar_type.hpp"
+#include "registration.h"
+
+// Note the CORE exstension will be built for (almost) all hardware targets so
+// new additions must account for this. (currently not built for TPU and Neuron)
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
+  // ScalarType, a custom class for representing data types that supports
+  // quantized types, declared here so it can be used when creating interfaces
+  // for custom ops.
+  aphrodite::ScalarTypeTorch::bind_class(lib);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

+ 1 - 1
kernels/cpu/torch_bindings.cpp

@@ -1,6 +1,6 @@
 #include "cache.h"
 #include "ops.h"
-#include "registration.h"
+#include "core/registration.h"
 
 #include <torch/library.h>
 

+ 1 - 1
kernels/moe/torch_bindings.cpp

@@ -1,4 +1,4 @@
-#include "registration.h"
+#include "../core/registration.h"
 #include "moe_ops.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {

+ 44 - 22
kernels/quantization/gptq_marlin/gptq_marlin.cu

@@ -21,6 +21,7 @@
 
 #include "marlin.cuh"
 #include "marlin_dtypes.cuh"
+#include "core/scalar_type.hpp"
 
 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
   static_assert(std::is_same<scalar_t, half>::value ||          \
@@ -71,14 +72,15 @@ __global__ void Marlin(
     bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {}
 
-}  // namespace gptq_marlin
+}  // namespace marlin
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                torch::Tensor& g_idx, torch::Tensor& perm,
-                               torch::Tensor& workspace, int64_t num_bits,
+                               torch::Tensor& workspace,
+                               aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full) {
+                               bool is_k_full, bool has_zp) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});
@@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
     __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
-                     void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
-                     int prob_m, int prob_n, int prob_k, void* workspace,
-                     int num_bits, bool has_act_order, bool is_k_full,
-                     bool has_zp, int num_groups, int group_size, int dev,
-                     cudaStream_t stream, int thread_k, int thread_n, int sms,
-                     int max_par, bool use_fp32_reduce) {
-  TORCH_CHECK(num_bits == 4 || num_bits == 8,
-              "num_bits must be 4 or 8. Got = ", num_bits);
+void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
+               void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
+               int prob_n, int prob_k, void* workspace,
+               aphrodite::ScalarType const& q_type, bool has_act_order,
+               bool is_k_full, bool has_zp, int num_groups, int group_size,
+               int dev, cudaStream_t stream, int thread_k, int thread_n,
+               int sms, int max_par, bool use_fp32_reduce) {
+  if (has_zp) {
+    TORCH_CHECK(
+        q_type == aphrodite::kU4 || q_type == aphrodite::kU8,
+        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
+  } else {
+    TORCH_CHECK(
+        q_type == aphrodite::kU4B8 || q_type == aphrodite::kU8B128,
+        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
+        q_type.str());
+  }
+
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
               ", ", prob_n, ", ", prob_k, "]");
 
+  // TODO: remove alias when we start supporting other 8bit types
+  int num_bits = q_type.size_bits();
   int tot_m = prob_m;
   int tot_m_blocks = div_ceil(tot_m, 16);
   int pad = 16 * tot_m_blocks - tot_m;
@@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
   }
 }
 
-}  // namespace gptq_marlin
+}  // namespace marlin
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                torch::Tensor& g_idx, torch::Tensor& perm,
-                               torch::Tensor& workspace, int64_t num_bits,
+                               torch::Tensor& workspace,
+                               aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                int64_t size_m, int64_t size_n, int64_t size_k,
                                bool is_k_full, bool has_zp,
                                bool use_fp32_reduce) {
-  // Verify num_bits
-  TORCH_CHECK(num_bits == 4 || num_bits == 8,
-              "num_bits must be 4 or 8. Got = ", num_bits);
-  int pack_factor = 32 / num_bits;
+  if (has_zp) {
+    TORCH_CHECK(*b_q_type == aphrodite::kU4 || *b_q_type == aphrodite::kU8,
+                "b_q_type must be u4 or u8 when has_zp = True. Got = ",
+                b_q_type->str());
+  } else {
+    TORCH_CHECK(
+        *b_q_type == aphrodite::kU4B8 || *b_q_type == aphrodite::kU8B128,
+        "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
+        b_q_type->str());
+  }
+
+  int pack_factor = 32 / b_q_type->size_bits();
 
   // Verify A
   TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
 
   int dev = a.get_device();
   if (a.scalar_type() == at::ScalarType::Half) {
-    marlin::marlin_mm_f16i4<half>(
+    marlin::marlin_mm<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
         c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
         b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
-        workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
+        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
-    marlin::marlin_mm_f16i4<nv_bfloat16>(
+    marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
         b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
-        workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
+        workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {

+ 10 - 7
kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu

@@ -27,6 +27,7 @@
 #include <iostream>
 
 #include "common/base.h"
+#include "../../../core/scalar_type.hpp"
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
@@ -86,7 +87,8 @@ __global__ void Marlin_24(
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
-                                  torch::Tensor& workspace, int64_t num_bits,
+                                  torch::Tensor& workspace,
+                                  aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
-                                  torch::Tensor& workspace, int64_t num_bits,
+                                  torch::Tensor& workspace,
+                                  aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k) {
   // Verify num_bits
-  TORCH_CHECK(num_bits == 4 || num_bits == 8,
-              "num_bits must be 4 or 8. Got = ", num_bits);
-  int pack_factor = 32 / num_bits;
+  TORCH_CHECK(*b_q_type == aphrodite::kU4B8 || *b_q_type == aphrodite::kU8B128,
+              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str());
+  int pack_factor = 32 / b_q_type->size_bits();
 
   // Verify M
   TORCH_CHECK(size_m == a.size(0),
@@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   marlin_24::marlin_cuda_2_4(
       a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
       b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
-      num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
-      thread_m, sms, max_par);
+      b_q_type->size_bits(), groupsize, dev,
+      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
 
   return c;
 }

+ 6 - 2
kernels/quantization/quant_ops.h

@@ -2,6 +2,8 @@
 
 #include <torch/library.h>
 
+#include "core/scalar_type.hpp"
+
 #ifndef USE_ROCM
 // AQLM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
@@ -64,14 +66,16 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   torch::Tensor& b_meta,
                                   torch::Tensor& b_scales,
-                                  torch::Tensor& workspace, int64_t num_bits,
+                                  torch::Tensor& workspace,
+                                  aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k);
 
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                torch::Tensor& g_idx, torch::Tensor& perm,
-                               torch::Tensor& workspace, int64_t num_bits,
+                               torch::Tensor& workspace,
+                               aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                int64_t size_m, int64_t size_n, int64_t size_k,
                                bool is_k_full, bool has_zp,
                                bool use_fp32_reduce);

+ 1 - 1
kernels/torch_bindings.cpp

@@ -1,7 +1,7 @@
 #include "cache.h"
 #include "cuda_utils.h"
 #include "ops.h"
-#include "registration.h"
+#include "core/registration.h"
 #include "quantization/quant_ops.h"
 
 #include <torch/library.h>

+ 8 - 1
setup.py

@@ -260,6 +260,10 @@ def _build_custom_ops() -> bool:
     return _is_cuda() or _is_hip() or _is_cpu()
 
 
+def _build_core_ext() -> bool:
+    return not _is_neuron() and not _is_tpu()
+
+
 def get_hipcc_rocm_version():
     # Run the hipcc --version command
     result = subprocess.run(['hipcc', '--version'],
@@ -421,6 +425,9 @@ def get_requirements() -> List[str]:
 
 ext_modules = []
 
+if _build_core_ext():
+    ext_modules.append(CMakeExtension(name="aphrodite._core_C"))
+
 if _is_cuda() or _is_hip():
     ext_modules.append(CMakeExtension(name="aphrodite._moe_C"))
 
@@ -466,7 +473,7 @@ setup(
         "ray": ["ray>=2.9"],
     },
     ext_modules=ext_modules,
-    cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
+    cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
     package_data=package_data,
     entry_points={
         "console_scripts": [