
feat: quant_llm support (#755)

* feat: add fp6 quant kernels

Co-authored-by: intervitens <intervitens@tutanota.com>

* add implementation

* formatting

* fix: add prefix

* codespell

* compile kernel for e3m3

* add fp4_e2m1 support

* add support for fp2, fp3 and configurable exponent

* bump capability to sm_80

* missed one more directive

* clean up

---------

Co-authored-by: intervitens <intervitens@tutanota.com>
AlpinDale committed 5 months ago
commit 73177656ed

+ 1 - 0
CMakeLists.txt

@@ -199,6 +199,7 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
   FetchContent_MakeAvailable(cutlass)
 
   list(APPEND APHRODITE_EXT_SRC
+    "kernels/quantization/fp6/fp6_linear.cu"
     "kernels/mamba/mamba_ssm/selective_scan_fwd.cu"
     "kernels/mamba/causal_conv1d/causal_conv1d.cu"
     "kernels/quantization/aqlm/gemm_kernels.cu"

+ 14 - 0
aphrodite/_custom_ops.py

@@ -465,6 +465,20 @@ def ggml_mul_mat_a8(
     return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
 
 
+# fp6
+def fp_eXmY_linear_forward_cuda(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: torch.Tensor,
+    _weights: torch.Tensor,
+    _scales: torch.Tensor,
+    splitK: int = 1,
+) -> torch.Tensor:
+    return torch.ops._C.fp_eXmY_linear_forward_cuda(EXPONENT, MANTISSA,
+                                                    _in_feats, _weights,
+                                                    _scales, splitK)
+
+
 # mamba
 def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                       bias_: Optional[torch.Tensor],

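For reference, a minimal usage sketch of the new wrapper (not part of the diff), assuming a CUDA device that meets the sm_80 requirement and a weight packed ahead of time with to_scaled_tc_fpx from aphrodite/quantization/utils/fp6_utils.py; the shapes and tensor names are illustrative:

    import torch
    from aphrodite import _custom_ops as ops
    from aphrodite.quantization.utils.fp6_utils import to_scaled_tc_fpx

    OC, IC, B = 4096, 4096, 16  # illustrative; the kernel requires OC % 256 == 0 and IC % 64 == 0
    w_fp16 = torch.randn(OC, IC, dtype=torch.float16, device="cuda")
    packed, scales = to_scaled_tc_fpx(w_fp16, ebits=2, mbits=3)  # FP6 = 1 sign + 2 exponent + 3 mantissa bits
    x = torch.randn(B, IC, dtype=torch.float16, device="cuda")
    y = ops.fp_eXmY_linear_forward_cuda(2, 3, x, packed, scales, splitK=1)
    print(y.shape)  # torch.Size([16, 4096])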
+ 75 - 0
aphrodite/common/config.py

@@ -48,6 +48,12 @@ _PP_SUPPORTED_MODELS = [
 ]
 
 _OPTIMIZED_QUANTS = [
+    "fp2",
+    "fp3",
+    "fp4",
+    "fp5",
+    "fp6",
+    "fp7",
     "fp8",
     "marlin",
     "gptq_marlin_24",
@@ -57,6 +63,7 @@ _OPTIMIZED_QUANTS = [
     "compressed-tensors",
     "compressed_tensors",
     "experts_int8",
+    "quant_llm",
 ]
 
 
@@ -95,6 +102,8 @@ class ModelConfig:
             weights. If None, we assume the model weights are not quantized.
         deepspeed_fp_bits: Number of bits to use for DeepSpeed FP quantization.
             Supported number of bits are: 4, 6, 8, 12.
+        quant_llm_fp_bits: Number of bits to use for QuantLLM FP quantization.
+            Supported number of bits are: 2, 3, 4, 5, 6, 7.
         quantization_param_path: Path to JSON file containing scaling factors.
             Used to load KV cache scaling factors into the model when KV cache
             type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
@@ -142,6 +151,8 @@ class ModelConfig:
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
         deepspeed_fp_bits: Optional[int] = None,
+        quant_llm_fp_bits: Optional[int] = None,
+        quant_llm_exp_bits: Optional[int] = None,
         quantization_param_path: Optional[str] = None,
         enforce_eager: Optional[bool] = None,
         max_context_len_to_capture: Optional[int] = None,
@@ -168,6 +179,8 @@ class ModelConfig:
             self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.deepspeed_fp_bits = deepspeed_fp_bits
+        self.quant_llm_fp_bits = quant_llm_fp_bits
+        self.quant_llm_exp_bits = quant_llm_exp_bits
         self.quantization_param_path = quantization_param_path
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
@@ -316,6 +329,68 @@ class ModelConfig:
                 "quant_method": "deepspeedfp"
             }
 
+        VALID_QUANT_LLM_FP_BITS = [2, 3, 4, 5, 6, 7]
+        VALID_QUANT_LLM_EXPONENTS = [1, 2, 3, 4, 5]
+        # The formula is mantissa_bits = fp_bits - exp_bits - 1
+        # The default exp_bits for each fp_bits are as follows:
+        DEFAULT_EXP_BITS = {
+            2: 1,
+            3: 2,
+            4: 2,
+            5: 2,
+            6: 2,
+            7: 3,
+        }
+
+        if self.quantization == "quant_llm":
+            if self.quant_llm_fp_bits is None:
+                raise ValueError(
+                    "quant_llm_fp_bits must be specified when using "
+                    "quant_llm quantization."
+                )
+            if self.quant_llm_fp_bits not in VALID_QUANT_LLM_FP_BITS:
+                raise ValueError(
+                    f"Invalid quant_llm_fp_bits: {self.quant_llm_fp_bits}. "
+                    f"Must be one of {VALID_QUANT_LLM_FP_BITS}."
+                )
+            if self.quant_llm_exp_bits is None:
+                self.quant_llm_exp_bits = DEFAULT_EXP_BITS[
+                    self.quant_llm_fp_bits]
+            else:
+                if self.quant_llm_exp_bits not in VALID_QUANT_LLM_EXPONENTS:
+                    raise ValueError(
+                        f"Invalid exponent bits: {self.quant_llm_exp_bits}. "
+                        f"Must be one of {VALID_QUANT_LLM_EXPONENTS}."
+                    )
+
+            self.hf_config.quantization_config = {
+                "bits": self.quant_llm_fp_bits,
+                "exp_bits": self.quant_llm_exp_bits,
+                "quant_method": "quant_llm"
+            }
+            
+        online_quant_methods = ["fp2", "fp3", "fp4", "fp5", "fp6", "fp7"]
+        if self.quantization is not None and self.quantization in \
+            online_quant_methods:
+            fp_bits = int(self.quantization[2])
+            if fp_bits not in VALID_QUANT_LLM_FP_BITS:
+                raise ValueError(
+                    f"Invalid quant_llm_fp_bits: {fp_bits}. "
+                    f"Must be one of {VALID_QUANT_LLM_FP_BITS}."
+                )
+            if fp_bits in [2, 3]:
+                logger.warning("FP2 and FP3 quantization methods lead to "
+                               "significant accuracy loss. Use them with "
+                               "caution. Model may be incoherent.")
+            exp_bits = DEFAULT_EXP_BITS[fp_bits]
+            self.hf_config.quantization_config = {
+                "bits": fp_bits,
+                "exp_bits": exp_bits,
+                "quant_method": self.quantization
+            }
+            self.dtype = torch.float16
+            self.enforce_eager = True
+
         if self.quantization is not None:
             if self.quantization not in supported_quantization:
                 raise ValueError(

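As a small illustration (not part of the diff), the defaults above map to the following eXmY formats under the mantissa_bits = fp_bits - exp_bits - 1 rule:

    DEFAULT_EXP_BITS = {2: 1, 3: 2, 4: 2, 5: 2, 6: 2, 7: 3}
    for fp_bits, exp_bits in DEFAULT_EXP_BITS.items():
        mantissa_bits = fp_bits - exp_bits - 1  # 1 sign bit + exp_bits + mantissa_bits
        print(f"fp{fp_bits} -> E{exp_bits}M{mantissa_bits}")
    # fp2 -> E1M0, fp3 -> E2M0, fp4 -> E2M1, fp5 -> E2M2, fp6 -> E2M3, fp7 -> E3M3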
+ 18 - 2
aphrodite/engine/args_tools.py

@@ -94,6 +94,8 @@ class EngineArgs:
     quantization_param_path: Optional[str] = None
     preemption_mode: Optional[str] = None
     deepspeed_fp_bits: Optional[int] = None
+    quant_llm_fp_bits: Optional[int] = None
+    quant_llm_exp_bits: Optional[int] = None
     # Cache Options
     kv_cache_dtype: str = "auto"
     block_size: int = 16
@@ -498,8 +500,20 @@ class EngineArgs:
                             type=int,
                             default=None,
                             help="Category: Quantization Options\n"
-                            "Number of floating bits to use for the deepseed "
-                            "quantization. Supported bits are: 4, 6, 8, 12. ")
+                            "Number of floating bits to use for the deepspeed "
+                            "quantization. Supported bits are: 4, 6, 8, 12.")
+        parser.add_argument("--quant-llm-fp-bits",
+                            type=int,
+                            default=None,
+                            help="Category: Quantization Options\n"
+                            "Number of floating bits to use for the quant_llm "
+                            "quantization. Supported bits are: 2 to 7.")
+        parser.add_argument("--quant-llm-exp-bits",
+                            type=int,
+                            default=None,
+                            help="Category: Quantization Options\n"
+                            "Number of exponent bits to use for the quant_llm "
+                            "quantization. Supported bits are: 1 to 5.")
         # Cache Options
         parser.add_argument(
             '--kv-cache-dtype',
@@ -886,6 +900,8 @@ class EngineArgs:
             max_model_len=self.max_model_len,
             quantization=self.quantization,
             deepspeed_fp_bits=self.deepspeed_fp_bits,
+            quant_llm_fp_bits=self.quant_llm_fp_bits,
+            quant_llm_exp_bits=self.quant_llm_exp_bits,
             quantization_param_path=self.quantization_param_path,
             enforce_eager=self.enforce_eager,
             max_context_len_to_capture=self.max_context_len_to_capture,

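A hedged usage sketch of the new flags through the EngineArgs dataclass shown above; the model name is illustrative, and on the command line the equivalent would be --quantization quant_llm --quant-llm-fp-bits 6 --quant-llm-exp-bits 2:

    from aphrodite.engine.args_tools import EngineArgs

    args = EngineArgs(
        model="meta-llama/Llama-2-7b-hf",  # illustrative model
        quantization="quant_llm",
        quant_llm_fp_bits=6,   # FP6 storage
        quant_llm_exp_bits=2,  # E2M3; if omitted, ModelConfig falls back to DEFAULT_EXP_BITS[6] == 2
    )
    # Both fields are forwarded to ModelConfig (see the hunk above), which validates
    # them and writes hf_config.quantization_config.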
+ 9 - 0
aphrodite/quantization/__init__.py

@@ -11,6 +11,7 @@ from aphrodite.quantization.deepspeedfp import DeepSpeedFPConfig
 from aphrodite.quantization.eetq import EETQConfig
 from aphrodite.quantization.experts_int8 import ExpertsInt8Config
 from aphrodite.quantization.fbgemm_fp8 import FBGEMMFp8Config
+from aphrodite.quantization.fp6 import QuantLLMFPConfig
 from aphrodite.quantization.fp8 import Fp8Config
 from aphrodite.quantization.gguf import GGUFConfig
 from aphrodite.quantization.gptq import GPTQConfig
@@ -29,6 +30,7 @@ QUANTIZATION_METHODS = {
     "tpu_int8": Int8TpuConfig,
     "eetq": EETQConfig,
     "fp8": Fp8Config,
+    "quant_llm": QuantLLMFPConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "gguf": GGUFConfig,
     # The order of gptq methods is important for config.py iteration over
@@ -44,6 +46,13 @@ QUANTIZATION_METHODS = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    # the quant_llm methods
+    "fp2": QuantLLMFPConfig,
+    "fp3": QuantLLMFPConfig,
+    "fp4": QuantLLMFPConfig,
+    "fp5": QuantLLMFPConfig,
+    "fp6": QuantLLMFPConfig,
+    "fp7": QuantLLMFPConfig,
 }
 
 

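As a quick illustration (not in the diff), the fpN shorthands and the explicit method name registered above all resolve to the same config class:

    from aphrodite.quantization import QUANTIZATION_METHODS

    assert QUANTIZATION_METHODS["quant_llm"] is QUANTIZATION_METHODS["fp6"]
    print(QUANTIZATION_METHODS["fp4"].__name__)  # QuantLLMFPConfig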
+ 198 - 0
aphrodite/quantization/fp6.py

@@ -0,0 +1,198 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from loguru import logger
+
+from aphrodite import _custom_ops as ops
+from aphrodite.distributed import get_tensor_model_parallel_rank
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.utils.fp6_utils import (_SPLIT_K_MAP,
+                                                    from_scaled_tc_fpx,
+                                                    to_scaled_tc_fpx)
+
+
+class QuantLLMFPConfig(QuantizationConfig):
+    """Config for QuantLLM FP quantizer. It supports fp2, fp3, fp4,
+    fp5, fp6, fp7.
+    
+    Reference: https://arxiv.org/abs/2401.14112
+    
+    Args: 
+        weight_bits: the target quantization bits, should be one of
+            2, 3, 4, 5, 6, 7.
+    """
+
+    def __init__(
+        self,
+        weight_bits: int = 6,
+        exp_bits: int = 2,
+    ) -> None:
+        self.weight_bits = weight_bits
+        self.exponent_bits = exp_bits
+
+        self.mantissa_bits = weight_bits - self.exponent_bits - 1
+
+        self.valid_types = [torch.float16]
+
+        if self.weight_bits not in [2, 3, 4, 5, 6, 7]:
+            raise ValueError(
+                "Currently, only 2- to 7-bit quantization is "
+                f"supported for QuantLLM FP quantization, but got "
+                f"{self.weight_bits} bits.")
+        
+        if get_tensor_model_parallel_rank() == 0:
+            logger.info(f"Loading model in FP{self.weight_bits}_E"
+                        f"{self.exponent_bits}M{self.mantissa_bits} format.")
+
+    def __repr__(self) -> str:
+        return (f"QuantLLMFPConfig(weight_bits={self.weight_bits}, "
+                f"exponent_bits={self.exponent_bits})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "QuantLLMFP"
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QuantLLMFPConfig":
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        exp_bits = cls.get_from_keys(config, ["exp_bits"])
+        return cls(weight_bits=weight_bits, exp_bits=exp_bits)
+
+    def get_linear_method(self) -> "QuantLLMFPLinearMethod":
+        return QuantLLMFPLinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    def get_quant_method(
+            self,
+            layer: torch.nn.Module,
+            prefix: str) -> Optional["QuantLLMFPLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return QuantLLMFPLinearMethod(self)
+        return None
+
+
+class QuantLLMFPLinearMethod(LinearMethodBase):
+    """Linear method for QuantLLMFP quantizer.
+    Args:
+        quant_config: the QuantLLMFP quantization config.
+    """
+
+    def __init__(self, quant_config: QuantLLMFPConfig):
+        self.quant_config = quant_config
+        self.weight = None
+
+    def create_weights(self,
+                       layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int],
+                       input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype,
+                       weight_loader=None,
+                       **extra_weight_attrs):
+        del output_size
+        del input_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight = QuantLLMFPParameter(
+            torch.Size((output_size_per_partition, input_size_per_partition)),
+            params_dtype=params_dtype,
+            quant_config=self.quant_config,
+        )
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+        })
+        layer.register_parameter("weight", weight)
+
+        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
+            # Calls the original weight loader (if any), quantizes the result,
+            # and then loads the quantized parameter.
+            if weight_loader is not None:
+                orig_param_data = param.data
+                param.data = param.quant_llmdequantize()
+                weight_loader(param, loaded_weight, *args, **kwargs)
+                param.data, loaded_weight = orig_param_data, param.data
+            param.quant_llmquantize_(loaded_weight.cuda())
+
+        extra_weight_attrs["weight_loader"] = quant_weight_loader
+        set_weight_attrs(weight, extra_weight_attrs)
+
+    def apply(self,
+              layer,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        weight = layer.weight
+        weights = weight.data
+        scales = weight.scales
+        out_dim, in_dim = weights.shape
+        bsize = x.shape[0]
+        splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(
+            out_dim, 1) if bsize <= 768 else 1
+        if bias is None:
+            return ops.fp_eXmY_linear_forward_cuda(
+                self.quant_config.exponent_bits,
+                self.quant_config.mantissa_bits,
+                x, weights, scales, splitK)
+        else:
+            return ops.fp_eXmY_linear_forward_cuda(
+                self.quant_config.exponent_bits,
+                self.quant_config.mantissa_bits,
+                x, weights, scales, splitK) + bias
+
+class QuantLLMFPParameter(nn.Parameter):
+    """
+    QuantLLMFP quantized parameter class that implements fp2 through fp7
+    quantization. Weights are stored in quantized form on
+    GPUs, and can be applied directly to float16 activations.
+    """
+
+    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
+                quant_config: QuantLLMFPConfig):
+
+        data = torch.empty(torch.Size((orig_shape[0],
+                            orig_shape[1] * quant_config.weight_bits // 8)),
+                                   dtype=torch.uint8)
+
+
+        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
+        self.scales = torch.empty(orig_shape[0],
+                                  dtype=torch.float16)
+        self.quant_config = quant_config
+        self.orig_shape = orig_shape
+        return self
+
+    def quant_llmquantize_(self, tensor: torch.Tensor):
+        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
+        data, scales = to_scaled_tc_fpx(
+            tensor.data, self.quant_config.exponent_bits,
+            self.quant_config.mantissa_bits)
+        self.data.copy_(data)
+        self.scales.copy_(scales)
+
+    def quant_llmdequantize(self, output_dtype=None):
+        output_dtype = output_dtype or torch.get_default_dtype()
+        return from_scaled_tc_fpx(self.data, self.quant_config.exponent_bits, 
+                        self.quant_config.mantissa_bits, self.scales
+                        ).to(output_dtype)

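For clarity, a small sketch (not in the diff) of the splitK heuristic used in QuantLLMFPLinearMethod.apply(): batch sizes are bucketed in steps of 64 tokens up to 768, and each bucket is keyed by the output dimension:

    from aphrodite.quantization.utils.fp6_utils import _SPLIT_K_MAP

    def pick_split_k(bsize: int, out_dim: int) -> int:
        return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1

    print(pick_split_k(1, 4096))     # 13 -> first bucket (tokens 1-64)
    print(pick_split_k(200, 8192))   # 5  -> fourth bucket (tokens 193-256)
    print(pick_split_k(1024, 4096))  # 1  -> beyond 768 tokens, no split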
+ 585 - 0
aphrodite/quantization/utils/fp6_utils.py

@@ -0,0 +1,585 @@
+# ruff: noqa
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
+# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
+#   1. No encodings are reserved for special values (+/-inf, NaN).
+#   2. When downcasting from FP32 to FPx,
+#      - Rounding mode is round to nearest, ties to even.
+#      - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
+#      magnitude (sign is preserved).
+from functools import reduce
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+
+def _n_ones(n: int) -> int:
+    return (1 << n) - 1
+
+
+EBITS_F32, MBITS_F32 = 8, 23
+F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
+_SPLIT_K_MAP = [
+    {  # tokens: [1, 64]
+        3072: 18,
+        4096: 13,
+        5120: 10,
+        6144: 9,
+        8192: 6,
+        10240: 5,
+        14336: 7,
+        28672: 7,
+        57344: 7
+    },
+    {  # tokens: [65:128]
+        3072: 9,
+        4096: 6,
+        5120: 5,
+        6144: 9,
+        8192: 3,
+        10240: 5,
+        14336: 7,
+        28672: 7,
+        57344: 6
+    },
+    {  # tokens: [129:192]
+        3072: 6,
+        4096: 4,
+        5120: 7,
+        6144: 3,
+        8192: 2,
+        10240: 5,
+        14336: 5,
+        28672: 5,
+        57344: 4
+    },
+    {  # tokens: [193:256]
+        3072: 9,
+        4096: 3,
+        5120: 5,
+        6144: 2,
+        8192: 5,
+        10240: 4,
+        14336: 8,
+        28672: 6,
+        57344: 4
+    },
+    {  # tokens: [257:320]
+        3072: 7,
+        4096: 5,
+        5120: 2,
+        6144: 5,
+        8192: 4,
+        10240: 1,
+        14336: 3,
+        28672: 3,
+        57344: 4
+    },
+    {  # tokens: [321:384]
+        3072: 3,
+        4096: 2,
+        5120: 5,
+        6144: 3,
+        8192: 1,
+        10240: 8,
+        14336: 3,
+        28672: 4,
+        57344: 3
+    },
+    {  # tokens: [385:448]
+        3072: 5,
+        4096: 7,
+        5120: 3,
+        6144: 5,
+        8192: 7,
+        10240: 3,
+        14336: 1,
+        28672: 1,
+        57344: 3
+    },
+    {  # tokens: [449:512]
+        3072: 2,
+        4096: 5,
+        5120: 4,
+        6144: 1,
+        8192: 5,
+        10240: 2,
+        14336: 6,
+        28672: 4,
+        57344: 1
+    },
+    {  # tokens: [513:576]
+        3072: 2,
+        4096: 3,
+        5120: 1,
+        6144: 1,
+        8192: 3,
+        10240: 3,
+        14336: 3,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [577:640]
+        3072: 5,
+        4096: 4,
+        5120: 1,
+        6144: 4,
+        8192: 2,
+        10240: 1,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [641:704]
+        3072: 3,
+        4096: 1,
+        5120: 2,
+        6144: 2,
+        8192: 1,
+        10240: 2,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [705:768]
+        3072: 3,
+        4096: 1,
+        5120: 3,
+        6144: 2,
+        8192: 1,
+        10240: 1,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    }
+]
+
+
+def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Note: there are no special values (NaN, inf) support in this code. Values
+    outside the representable range of FPx after rounding are clamped to the
+    maximum FPx magnitude (sign is preserved).
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+    """
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    max_int = _n_ones(ebits + mbits)
+    sign_mask = 1 << (ebits + mbits)
+
+    # TODO document this better
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    x = x ^ sign
+
+    # TODO: can the branch floating point comparisons below be done without
+    # converting to float? probably but need to verify
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    #
+    # branch 1: saturate to max val - handled later in the code which combines
+    #   the branches
+    #
+
+    #
+    # branch 2: to conversion to denormal as well as rounding up to normal
+    #
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    #
+    # branch 3: stay in normal range, adjust the exponent and round
+    #
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    #
+    # combine the branches
+    #
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+# TODO(future): check if LUT for everything is faster than bit shifting,
+# especially for fp4 (only 2^4=16 unique values).
+def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
+    """
+    assert x.dtype == torch.uint8
+    assert 1 + ebits + mbits <= 8
+
+    sign_mask = 1 << (ebits + mbits)
+    exp_bias = _n_ones(ebits - 1)
+    mantissa_mask = _n_ones(mbits)
+
+    # save the sign
+    sign_lp = x & sign_mask
+
+    # set everything to positive, will add sign back at the end
+    x_pos = x ^ sign_lp
+
+    #
+    # 1. Calculate zero mask
+    #
+    zero_mask = x_pos == 0
+
+    #
+    # 2. Calculate the denormal path mask
+    #
+    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
+
+    #
+    # 3. Calculate the normal path
+    #
+
+    # calculate the new exponent and shift it to bits 2:9 of the result
+    exp_biased_lp = x_pos >> mbits
+    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
+    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
+
+    # shift the mantissa to bits 10:32 of the result
+    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
+    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
+    result = exp_biased_f32 | mantissa_f32
+
+    #
+    # 4. Add the zero and denormal casts to the already casted normal path
+    #
+    result[zero_mask] = 0
+
+    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
+
+    # fast path.
+    # without this, performance for FP4_E2M1 is slower by 2x
+    if mbits == 1:
+        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
+
+    else:
+        # iterate over all possible values of mantissa
+        # i=0, j=1
+        # i=1, j=10,11
+        # i=2, j=100,101,110,111
+        # and so on
+        for i in range(mbits):
+            for mantissa_cmp in range(1 << i, 1 << (i+1)):
+                # left shift mantissa until it overflows (create an implicit 1)
+                # subtract exponent by the same amount
+                left_shift = mbits - i
+                mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
+                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
+
+                # we can update this in-place since the values won't overlap
+                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
+                # thus we use + instead of | here
+                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32
+
+        result = torch.where(denormal_mask, mantissa_lp_int32, result)
+
+    # add sign back
+    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
+    result = result | sign_f32
+
+    return result.view(torch.float)
+
+
+def quant_llm_linear(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: Tensor,
+    _weights: Tensor,
+    _scales: Tensor,
+    splitK: int = 1,
+) -> Tensor:
+    """
+    Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
+    Arguments
+        EXPONENT: number of exponent bits
+        MANTISSA: number of mantissa bits
+        _in_feats: input activations in FP16
+        _weights: packed FPx weights
+        _scales: scale
+        splitK: split K
+    Returns
+        output of linear layer
+    """
+    return torch.ops._C.fp_eXmY_linear_forward_cuda(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK)
+
+
+_ONES_TABLE = [_n_ones(i) for i in range(8)]
+
+
+def _pack(x: Tensor, n_bits: int) -> Tensor:
+    return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
+
+
+def _unpack(x: Tensor, n_bits: int) -> Tensor:
+    return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
+
+
+# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
+def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
+    # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8
+    # thus, we need to reverse byte order within a uint32 word.
+    x = x.reshape(-1, 4).flip(1)
+
+    x = _unpack(x, n_bits)
+    x = x.view(-1, 4 * (8 // n_bits))
+
+    if not undo:
+        bit_order = {
+            1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
+                0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
+            2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
+            4: [1, 5, 3, 7, 0, 4, 2, 6],
+        }[n_bits]
+
+    else:
+        # this is inverse of the above, obtained by running
+        # [v.index(i) for i in range(len(v))]
+        bit_order = {
+            1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
+                20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
+            2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
+            4: [4, 0, 6, 2, 5, 1, 7, 3],
+        }[n_bits]
+
+    x = x[:, bit_order]
+    x = _pack(x, n_bits)
+
+    # reverse byte order within a uint32 word again.
+    x = x.reshape(-1, 4).flip(1)
+    return x.flatten()
+
+
+# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
+# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h
+def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    assert tensor.ndim == 2 and tensor.dtype == torch.uint8
+    M, N = tensor.shape
+    assert (M % 64 == 0) and (N % 64 == 0)
+
+    # Pass 1 from original code
+    tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8)
+    tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6)
+    tensor = tensor.reshape(-1, 32, 2)
+    tensor = tensor.permute(1, 0, 2)
+    tensor = tensor.flatten()
+
+    used_bits = 0
+    fragments = []
+
+    for y in [1, 2, 4]:
+        if nbits & y:
+            mask = (1 << y) - 1
+            tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
+            tensor_ybit = _pack(tensor_ybit, y)
+
+            tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
+            tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y)
+            fragments.append(tensor_ybit)
+            used_bits += y
+
+    return torch.cat(fragments, dim=0).view(M, -1)
+
+
+# more optimized version of _pack_tc_fpx() for FP6 by merging ops
+def _pack_tc_fp6(tensor: Tensor) -> Tensor:
+    assert tensor.ndim == 2 and tensor.dtype == torch.uint8
+    M, N = tensor.shape
+    assert (M % 64 == 0) and (N % 64 == 0)
+
+    tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
+    tensor = tensor.flip(3)
+
+    tensor_2bit = (tensor >> 4) & 0b11
+    tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
+    tensor_2bit = _pack(tensor_2bit.flatten(), 2)
+
+    tensor_4bit = tensor & 0b1111
+    tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
+    tensor_4bit = _pack(tensor_4bit.flatten(), 4)
+
+    return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1)
+
+
+# currently only optimize for TC-FP6 packing
+def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    if nbits == 6:
+        return _pack_tc_fp6(tensor)
+    return _pack_tc_fpx(tensor, nbits)
+
+
+def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]:
+    # _n_ones() is not compatible with torch.compile() due to << operator
+    # https://github.com/pytorch/pytorch/issues/119152
+    # exp_bias = _n_ones(ebits - 1)
+    # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
+
+    # workaround: global lookup table
+    exp_bias = _ONES_TABLE[ebits - 1]
+    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
+
+    tensor = tensor.float()
+    scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
+    tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
+    return tensor_tc_fpx, scale.half()
+
+
+# inverse of _pack_tc_fpx()
+def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    assert tensor.ndim == 2 and tensor.dtype == torch.uint8
+    M = tensor.shape[0]
+    size = tensor.numel()
+    tensor = tensor.flatten()
+    offset = 0
+    used_bits = 0
+
+    tensor_fpx = None
+
+    for y in [1, 2, 4]:
+        if nbits & y:
+            size_ybit = size // nbits * y
+            tensor_ybit = tensor[offset : offset + size_ybit]
+            offset += size_ybit
+
+            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)            # undo Pass 3
+            tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)  # undo Pass 2
+
+            tensor_ybit = _unpack(tensor_ybit.flatten(), y)
+            tensor_ybit = tensor_ybit << (nbits - used_bits - y)
+            used_bits += y
+
+            if tensor_fpx is None:
+                tensor_fpx = tensor_ybit
+            else:
+                tensor_fpx |= tensor_ybit
+
+    # undo Pass 1
+    tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2)
+    tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8)
+    tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6)
+    tensor_fpx = tensor_fpx.reshape(M, -1)
+    return tensor_fpx
+
+
+# more optimized version of _unpack_tc_fpx() for FP6 by merging ops
+# inverse of _unpack_tc_fp6()
+def _unpack_tc_fp6(tensor: Tensor) -> Tensor:
+    assert tensor.ndim == 2 and tensor.dtype == torch.uint8
+    M = tensor.shape[0]
+    N = tensor.shape[1] // 3 * 4
+    assert (M % 64 == 0) and (N % 64 == 0)
+    size_2bit = M * N // 4
+    size_4bit = M * N // 2
+    tensor = tensor.view(-1)
+    assert tensor.numel() == size_2bit + size_4bit
+
+    tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])
+
+    tensor_2bit = _unpack(tensor_2bit, 2)
+    tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
+    tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)
+
+    tensor_4bit = _unpack(tensor_4bit, 4)
+    tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
+    tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)
+
+    tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
+    tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
+    return tensor_fp6
+
+
+def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    if nbits == 6:
+        return _unpack_tc_fp6(tensor)
+    return _unpack_tc_fpx(tensor, nbits)
+
+
+def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor:
+    fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
+    tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
+    if scale is not None:
+        tensor = tensor * scale.float().view(-1, 1)
+    return tensor

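A minimal round-trip sketch (not part of the diff) for the packing helpers above; the shape is illustrative and both dimensions must be divisible by 64:

    import torch
    from aphrodite.quantization.utils.fp6_utils import to_scaled_tc_fpx, from_scaled_tc_fpx

    w = torch.randn(256, 1024, dtype=torch.float16)
    packed, scales = to_scaled_tc_fpx(w, ebits=2, mbits=3)              # FP6 E2M3, tensor-core layout
    w_hat = from_scaled_tc_fpx(packed, ebits=2, mbits=3, scale=scales)  # back to FP32
    print(packed.shape, packed.dtype)       # torch.Size([256, 768]) torch.uint8
    print((w.float() - w_hat).abs().max())  # quantization error, small relative to the weight scale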
+ 73 - 0
kernels/quantization/fp6/configs.h

@@ -0,0 +1,73 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is copied from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h
+
+#ifndef CONFIGS_H
+#define CONFIGS_H
+
+// #define DEBUG_MODE
+#define PIPELINE_LEVEL_GMEM 2
+#define PIPELINE_LEVEL_SMEM 2  // only support 2
+
+/************************ Hardware Parameters ************************/
+#define WARP_SIZE 32
+#define REG_BIT_WIDTH 32
+// mma: M=16 K=16 N=8
+#define MMA_8 8
+#define MMA_16 16
+// for memory access
+#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128  // LDS.128, cp_async.128, ...
+#define BIT_WIDTH_PER_HALF 16                // Half precision: FP16
+
+/******************** Register Allocation For GEMM ********************/
+#define REG_PER_THREAD_C_TENSOR_16_16 8  // 8 for FP32 Accumulation
+/********************** Memory Padding Parameters **********************/
+// Eliminating bank-conflict
+#define PADDING_BYTES_16 16  // Padding 16 bytes each column
+#define PADDING_SHARED_MEM_FOR_B_8 \
+  8  // Padding 8 half  each column, during CopyFromGlobalToShared() for B
+#define PADDING_SHARED_MEM_FOR_C_4 \
+  4  // Padding 4 float each column, during StoreToSharedMemoryFromRegister()
+     // for C
+/************************* WARP Tiling part-1 *************************/
+#define WARP_ROW_MMA_TENSORS 4
+#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16)  // 64
+#define WARP_K_MMA_TENSORS 4
+#define WARP_K (WARP_K_MMA_TENSORS * MMA_16)  // 64
+template <int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
+struct TilingConfig {
+  // Depending on "n" dimension of the GEMM
+  static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_;
+  static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_;
+  static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
+  /************************* WARP Tiling part-2 *************************/
+  static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;
+  /*************************Thread Block Tiling *************************/
+  static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
+  static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
+  static constexpr int TILE_K = WARP_K;
+  /********************** #Thread per Thread Block **********************/
+  static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
+  static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
+  /******************************* Others *******************************/
+  static constexpr int SMEM_SIZE_B_TILE =
+      TILE_N * (TILE_K + PADDING_BYTES_16) * 2 *
+      PIPELINE_LEVEL_GMEM;  // sizeof(half)=2, doubleBuffer=2
+  static constexpr int SMEM_SIZE_C_TILE =
+      TILE_N * (TILE_M + PADDING_BYTES_16) * 4;  // sizeof(float)=4
+};
+
+#endif  // CONFIGS_H

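For orientation, a short sketch (not in the diff) that reproduces the constexpr tile arithmetic above in Python for the two TilingConfig instantiations used by fp6_linear.cu:

    def tiling(block_row_warps: int, block_col_warps: int, warp_col_mma_tensors: int):
        WARP_M = WARP_K = 64  # WARP_ROW_MMA_TENSORS * MMA_16 and WARP_K_MMA_TENSORS * MMA_16
        tile_m = WARP_M * block_row_warps
        tile_n = 8 * warp_col_mma_tensors * block_col_warps  # MMA_8 = 8
        tile_k = WARP_K
        threads = block_row_warps * block_col_warps * 32     # BLOCK_WARPS * WARP_SIZE
        return tile_m, tile_n, tile_k, threads

    print(tiling(4, 1, 1))  # (256, 8, 64, 128)  -> used for N <= 8
    print(tiling(4, 1, 8))  # (256, 64, 64, 128) -> used for N >= 64; TILE_M = 256 matches
                            #    the M_Global % 256 == 0 assert in fp6_linear.cu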
+ 332 - 0
kernels/quantization/fp6/fp6_linear.cu

@@ -0,0 +1,332 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is adapted from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
+
+#include "kernel_matmul.cuh"
+#include "kernel_reduction.cuh"
+
+#include <stdio.h>
+#include <assert.h>
+
+namespace aphrodite {
+
+template <typename TilingConfig, typename OutputDataType, int EXPONENT,
+          int MANTISSA>
+static void Kernel_Ex(cudaStream_t stream, const uint4* Weight,
+                      const half* Scales, const half* B, OutputDataType* C,
+                      const size_t M_Global, const size_t N_Global,
+                      const size_t K_Global, int Split_K) {
+#ifdef DEBUG_MODE
+  printf("\n");
+  printf("Launcher.cu->Kernel_Ex():\n");
+  printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global,
+         Split_K);
+  printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M,
+         TilingConfig::TILE_K, TilingConfig::TILE_N);
+#endif
+  static size_t SHMEM_SZ =
+      max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE,
+          TilingConfig::SMEM_SIZE_C_TILE);
+  cudaFuncSetAttribute(
+      QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>,
+      cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
+  size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
+  size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
+  dim3 GridDim(dimN, dimM, 1);
+  dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
+//
+#ifdef DEBUG_MODE
+  printf(
+      "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, "
+      "BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n",
+      GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z,
+      SHMEM_SZ);
+  printf("\n");
+#endif
+  QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>
+      <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global,
+                                                N_Global, K_Global, Split_K);
+}
+
+template <int EXPONENT, int MANTISSA>
+cudaError_t fpx_linear_kernel(
+    cudaStream_t stream, const uint4* Weight, const half* Scales, const half* B,
+    half* C, const size_t M_Global, const size_t N_Global,
+    const size_t K_Global,
+    float* Reduction_Workspace,  // Reduction_Workspace_Size = Split_K *
+                                 // M_Global * N_Global * sizeof(fp32)
+    int Split_K) {
+  assert(M_Global % 256 == 0);
+  assert(K_Global % 64 == 0);
+  assert(N_Global > 0);
+
+  // Work around to support more N shapes:
+  size_t N_PowerOf2;
+  if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8;
+  if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16;
+  if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32;
+  if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64;
+  if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128;
+  if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128;
+
+  if (Split_K == 1) {
+    switch (N_PowerOf2) {
+      case 8:
+        Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 16:
+        Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 32:
+        Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 64:
+        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 128:
+        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      default:
+        if (N_PowerOf2 % 128 != 0) {
+          printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
+          return cudaErrorUnknown;
+        }
+        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+    }
+  } else {
+    switch (N_PowerOf2) {
+      case 8:
+        Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 16:
+        Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 32:
+        Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 64:
+        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 128:
+        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      default:
+        if (N_PowerOf2 % 128 != 0) {
+          printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
+          return cudaErrorUnknown;
+        }
+        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+    }
+    // Reduction for SplitK
+    dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1,
+                 1);
+    dim3 BlockDim(WARP_SIZE, 1, 1);
+    SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(
+        C, Reduction_Workspace, M_Global, N_Global, Split_K);
+  }
+  return cudaGetLastError();
+}
+}  // namespace aphrodite
+
+#include <torch/all.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/library.h>
+
+// MODIFICATION NOTE: dtype of _weights is changed to uint8
+/*
+Computes FPx-FP16 GEMM (PyTorch interface).
+[Mathematical Formula]
+Standard definition of linear layer: Out = In * trans(W), where In, Out, and W
+are stored in row-major. After equivalent transformation: trans(Out) = W * trans(In).
+Note that we do not perform "transpose" during runtime; we instead interpret
+In/Out as column-major matrices when calling our CUDA kernel.
+[Inputs]
+  _in_feats: tensor of shape [B, IC];                // half
+  _weights:  int tensor of shape [OC, IC // 8 * x];  // x UINT8 words contain 8 FPx weights
+  _scales:   tensor of shape [OC];                   // half
+  splitK:    split the MatMul along the K dimension for higher GPU utilization; default 1
+[Outputs]
+  _out_feats: tensor of shape [B, OC];               // half
+*/
+torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
+                                          torch::Tensor _in_feats,
+                                          torch::Tensor _weights,
+                                          torch::Tensor _scales,
+                                          int64_t splitK = 1) {
+  const int64_t NBITS = 1 + EXPONENT + MANTISSA;
+  int num_in_feats = _in_feats.size(0);
+  int num_in_channels = _in_feats.size(1);
+  int num_out_channels = _weights.size(0);
+  TORCH_CHECK(num_in_channels % 64 == 0,
+              "Expected in_features to be a multiple of 64, but received ",
+              num_in_channels);
+  TORCH_CHECK((num_in_channels / 8 * NBITS) ==
+              _weights.size(1));  // Making sure the K dimension is matched.
+  //
+  int M = num_out_channels;
+  int K = num_in_channels;
+  int N = num_in_feats;
+  // Input Tensors
+  auto weight = reinterpret_cast<const uint4*>(
+      _weights.data_ptr<uint8_t>());  // weights is [OC, IC] but in FP6.
+  auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
+  auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
+  // Output Tensors
+  auto options = torch::TensorOptions()
+                     .dtype(_in_feats.dtype())
+                     .device(_in_feats.device());
+  at::Tensor _out_feats =
+      torch::empty({num_in_feats, num_out_channels}, options);
+  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+
+  options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
+  at::Tensor _workspace =
+      torch::empty({splitK, num_in_feats, num_out_channels}, options);
+  auto Reduction_Workspace = reinterpret_cast<float*>(
+      _workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K *
+                                      // M_Global * N_Global * sizeof(fp32)
+
+  // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default
+  // stream (0) this fixes problem with CUDA graphs when used with
+  // torch.compile()
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  /*
+   The heuristic is weight_bit - exponent_bit - 1 = mantissa_bit
+   */
+
+  // FP2
+  if (EXPONENT == 1 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<1, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  // FP3
+  else if (EXPONENT == 1 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<1, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  // FP4
+  else if (EXPONENT == 1 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<1, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  // FP5
+  else if (EXPONENT == 1 && MANTISSA == 3)
+    aphrodite::fpx_linear_kernel<1, 3>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 4 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<4, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  // FP6
+  else if (EXPONENT == 1 && MANTISSA == 4)
+    aphrodite::fpx_linear_kernel<1, 4>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 3)
+    aphrodite::fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 4 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<4, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 5 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<5, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  // FP7
+  else if (EXPONENT == 1 && MANTISSA == 5)
+    aphrodite::fpx_linear_kernel<1, 5>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 4)
+    aphrodite::fpx_linear_kernel<2, 4>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 3)
+    aphrodite::fpx_linear_kernel<3, 3>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 4 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<4, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 5 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<5, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  else
+    TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA,
+                " is not supported.");
+
+  return _out_feats;
+}

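A small arithmetic sketch (not in the diff) for the Split-K reduction workspace allocated above; the sizes are illustrative:

    # Split_K partial results are accumulated in FP32 and reduced into the FP16 output.
    split_k, batch, out_channels = 5, 16, 4096
    workspace_bytes = split_k * batch * out_channels * 4  # sizeof(fp32)
    print(workspace_bytes)  # 1310720 bytes (~1.25 MiB)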
+ 354 - 0
kernels/quantization/fp6/kernel_matmul.cuh

@@ -0,0 +1,354 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh
+
+#include "configs.h"
+#include "utils_gmem.cuh"
+#include "utils_core.cuh"
+
+/************************ Bitwidth of Weight Segments ************************/
+#define BIT_WIDTH_1 1
+#define BIT_WIDTH_2 2
+#define BIT_WIDTH_4 4
+/*********************** 64*64 Weights of Weight Matrix **********************/
+#define WEIGHT_PER_WARP (WARP_M * WARP_K)  // 64*64 = 4096
+#define SMEM_SIZE_PER_WARP_1BIT    \
+  (WEIGHT_PER_WARP * BIT_WIDTH_1 / \
+   8)  // 512 Bytes,  doubleBuffer not taken into consideration
+#define SMEM_SIZE_PER_WARP_2BIT    \
+  (WEIGHT_PER_WARP * BIT_WIDTH_2 / \
+   8)  // 1024 Bytes, doubleBuffer not taken into consideration
+#define SMEM_SIZE_PER_WARP_4BIT    \
+  (WEIGHT_PER_WARP * BIT_WIDTH_4 / \
+   8)  // 2048 Bytes, doubleBuffer not taken into consideration
+#define SMEM_SIZE_PER_TB_1BIT                            \
+  (SMEM_SIZE_PER_WARP_1BIT * TilingConfig::BLOCK_WARPS * \
+   PIPELINE_LEVEL_GMEM)  // #WARP=4; Triple-Buffer for 3-level pipeline for A
+                         // = 6 KB;  double buffer for 2-level pipeline A = 4
+                         // KB.
+#define SMEM_SIZE_PER_TB_2BIT                            \
+  (SMEM_SIZE_PER_WARP_2BIT * TilingConfig::BLOCK_WARPS * \
+   PIPELINE_LEVEL_GMEM)  // #WARP=4; Triple-Buffer for 3-level pipeline for A
+                         // = 12 KB; double buffer for 2-level pipeline A = 8
+                         // KB.
+#define SMEM_SIZE_PER_TB_4BIT                            \
+  (SMEM_SIZE_PER_WARP_4BIT * TilingConfig::BLOCK_WARPS * \
+   PIPELINE_LEVEL_GMEM)  // #WARP=4; Triple-Buffer for 3-level pipeline for A
+                         // = 24 KB; double buffer for 2-level pipeline A = 16
+                         // KB.
+#define SMEM_SIZE_PER_TB_A_TILE                    \
+  (SMEM_SIZE_PER_TB_1BIT + SMEM_SIZE_PER_TB_2BIT + \
+   SMEM_SIZE_PER_TB_4BIT)  // used in fp6_linear.cu, Kernel_Ex().
+/******************** Global Memory Layout For QUANTIZED DATA
+ * *******************/
+#define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP * BIT_WIDTH_1 / 128)  // 32
+#define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP * BIT_WIDTH_2 / 128)  // 64
+#define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP * BIT_WIDTH_4 / 128)  // 128
+
+/*
+ * C = A*B
+ * A: row major with ahead-of-time layout transformation, FP6
+ * B: col major, FP16
+ * C: col major, FP16
+ */
+template <typename TilingConfig, typename OutputDataType, int EXPONENT,
+          int MANTISSA>
+__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
+                                  const half* B, OutputDataType* C,
+                                  const size_t M_Global, const size_t N_Global,
+                                  const size_t K_Global, int Split_K) {
+#ifdef DEBUG_MODE
+  assert(K_Global % TilingConfig::TILE_K == 0);
+  assert(M_Global % TilingConfig::TILE_M == 0);
+  assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M));
+#endif
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+  const uint4* Weight_1bit = Weight;
+  const uint4* Weight_2bit =
+      Weight_1bit +
+      (USE_SEG_1BIT ? M_Global * K_Global * BIT_WIDTH_1 / 128 : 0);
+  const uint4* Weight_4bit =
+      Weight_2bit +
+      (USE_SEG_2BIT ? M_Global * K_Global * BIT_WIDTH_2 / 128 : 0);
+  // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned
+  extern __shared__ __align__(128) half smem[];
+  half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+      reinterpret_cast<half(*)[WARP_K + PADDING_SHARED_MEM_FOR_B_8]>(
+          smem + SMEM_SIZE_PER_TB_A_TILE /
+                     2);  // Dynamic shared memory for FP16 B tiles
+  __shared__ half
+      QuantScales[64 *
+                  TilingConfig::BLOCK_WARPS];  // static shared memory for
+                                               // quantization scales, 64 row
+                                               // per warp * 4 warps = 512 Bytes
+  // Thread Block Mapping, considering SplitK
+  const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M);
+  const size_t x =
+      blockIdx.x;  // Output Block ID: (BlockID_Row = y; BlockID_Col = x )
+  const size_t y =
+      blockIdx.y %
+      (M_Global / TilingConfig::TILE_M);  // Output Block ID: (BlockID_Row = y;
+                                          // BlockID_Col = x )
+  const size_t Tile_Start_M = y * TilingConfig::TILE_M;
+  const size_t Tile_Start_N = x * TilingConfig::TILE_N;
+  const size_t NumColumnToCopy =
+      (N_Global - Tile_Start_N) < TilingConfig::TILE_N
+          ? (N_Global - Tile_Start_N)
+          : TilingConfig::TILE_N;
+  const size_t NumBlock_K = K_Global / TilingConfig::TILE_K;
+  const size_t AverageNumBlock_K = NumBlock_K / Split_K;
+  const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K;
+  size_t NumIter = AverageNumBlock_K;
+  size_t StartBlockID_K = AverageNumBlock_K * BatchID;
+  if (BatchID < ExtraNumBlock_K) {
+    NumIter++;
+    StartBlockID_K += BatchID;
+  } else
+    StartBlockID_K += ExtraNumBlock_K;
+  // Warp ID.
+  const int warpId = threadIdx.x / WARP_SIZE;
+  int WARP_i = warpId / TilingConfig::BLOCK_COL_WARPS;  // WARP_i: row number;
+                                                        // WARP_j: column number
+  // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
+  //  Global Memory Address for Matrix A (Weight)
+  //  /////////////////////////////////////////////////////////////////////////
+  //  StartPTR for each ThreadBlock(TB)
+  const uint4* TB_StartGPTR_A_1BIT =
+      Weight_1bit +
+      (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
+  const uint4* TB_StartGPTR_A_2BIT =
+      Weight_2bit +
+      (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
+  const uint4* TB_StartGPTR_A_4BIT =
+      Weight_4bit +
+      (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
+  // StartPTR for each WARP.
+  const uint4* WARP_StartGPTR_A_1BIT =
+      TB_StartGPTR_A_1BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
+  const uint4* WARP_StartGPTR_A_2BIT =
+      TB_StartGPTR_A_2BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
+  const uint4* WARP_StartGPTR_A_4BIT =
+      TB_StartGPTR_A_4BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
+  // StartPTR for each WARP, considering SplitK
+  const size_t WARP_Start_UnitID_K = StartBlockID_K;
+  WARP_StartGPTR_A_1BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_1BIT;
+  WARP_StartGPTR_A_2BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_2BIT;
+  WARP_StartGPTR_A_4BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_4BIT;
+  // Copying A tile from Global to Shared, using double-buffer
+  // //////////////////////////////////////////////////////////
+  // StartSPTR for each ThreadBlock
+  uint32_t* AFrag_1BIT_SPTR = reinterpret_cast<uint32_t*>(smem);
+  uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT / 4;
+  uint32_t* AFrag_4BIT_SPTR =
+      AFrag_2BIT_SPTR +
+      SMEM_SIZE_PER_TB_2BIT /
+          4;  // 8 buffers including double buffers, 12 for triple buffers
+  // StartSPTR for each WARP
+  AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT / 4;
+  AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT / 4;
+  AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT / 4;
+  // Pre-fetch of A tile
+  for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) {
+    if (USE_SEG_1BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(
+          AFrag_1BIT_SPTR + i * SMEM_SIZE_PER_WARP_1BIT / 4 * 4,
+          WARP_StartGPTR_A_1BIT);
+    if (USE_SEG_2BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(
+          AFrag_2BIT_SPTR + i * SMEM_SIZE_PER_WARP_2BIT / 4 * 4,
+          WARP_StartGPTR_A_2BIT);
+    if (USE_SEG_4BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(
+          AFrag_4BIT_SPTR + i * SMEM_SIZE_PER_WARP_4BIT / 4 * 4,
+          WARP_StartGPTR_A_4BIT);
+    WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT / 16;
+    WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT / 16;
+    WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT / 16;
+  }
+  // Global Memory Address for Matrix A (QuantScale)
+  // /////////////////////////////////////////////////////////////////////
+  const half* TB_StartGPTR_A_Scale =
+      Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64;
+  const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64;
+  CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64,
+                                WARP_StartGPTR_A_Scales);
+  // Copying B tile from Global to Shared, considering SplitK
+  // /////////////////////////////////////////////////////////////
+  const half* BTile_GPTR =
+      B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K;
+  for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) {
+    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>(
+        smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global,
+        NumColumnToCopy);
+    BTile_GPTR += TilingConfig::TILE_K;
+  }
+  // Register Allocation for A, B, and C, Initialized to Zeros
+  // /////////////////////////////////////////////////////////////////////
+  constexpr int NumRegSets_a =
+      WARP_ROW_MMA_TENSORS;  // 1 set = 4 registers, containing a 16*16 MMA
+                             // block
+  constexpr int NumRegSets_b =
+      (TilingConfig::WARP_COL_MMA_TENSORS == 1)
+          ? 1
+          : TilingConfig::WARP_COL_MMA_TENSORS /
+                2;  // 1 set = 4 registers, containing a 16*16 MMA block
+  uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM]
+            [4];  // double/Triple buffer is used // Registers to store
+                  // decompressed FP6
+  uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM]
+            [4];  // double/Triple buffer is used // Register to store FP16 B
+                  // matrix (a slice)
+  float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16];
+  for (int i = 0; i < NumRegSets_a * NumRegSets_b; i++)
+    for (int j = 0; j < REG_PER_THREAD_C_TENSOR_16_16; j++) c[i][j] = 0.0f;
+  //
+  cp_async_wait_all();
+  __syncthreads();
+
+  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+  uint32_t Scales_RPTR[4];  // 4 Registers per thread for Quantization Scales
+  ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i * 64);
+  // Initializing the Software Pipeline: writing registers.
+  // ////////////////////////////////////////////////////////////////////////////////////////////////
+  initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+      a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array,
+      Scales_RPTR);
+// The outer loop.
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+#pragma unroll(1)
+  for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) {
+    // Triple-Buffer for A Tile
+    uint32_t* __restrict__ read_SPTR_Frag_1bit =
+        AFrag_1BIT_SPTR +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT / 4 *
+            4;  // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ read_SPTR_Frag_2bit =
+        AFrag_2BIT_SPTR +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT / 4 *
+            4;  // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ read_SPTR_Frag_4bit =
+        AFrag_4BIT_SPTR +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT / 4 *
+            4;  // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ read2_SPTR_Frag_1bit =
+        AFrag_1BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
+                              SMEM_SIZE_PER_WARP_1BIT / 4 * 4;
+    uint32_t* __restrict__ read2_SPTR_Frag_2bit =
+        AFrag_2BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
+                              SMEM_SIZE_PER_WARP_2BIT / 4 * 4;
+    uint32_t* __restrict__ read2_SPTR_Frag_4bit =
+        AFrag_4BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
+                              SMEM_SIZE_PER_WARP_4BIT / 4 * 4;
+    uint32_t* __restrict__ write_SPTR_Frag_1bit =
+        AFrag_1BIT_SPTR +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            SMEM_SIZE_PER_WARP_1BIT / 4 *
+            4;  // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ write_SPTR_Frag_2bit =
+        AFrag_2BIT_SPTR +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            SMEM_SIZE_PER_WARP_2BIT / 4 *
+            4;  // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ write_SPTR_Frag_4bit =
+        AFrag_4BIT_SPTR +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            SMEM_SIZE_PER_WARP_4BIT / 4 *
+            4;  // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    // Triple-Buffer for B Tile
+    // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is
+    // changed to below. similarly for read2_SPTR and write_SPTR.
+    half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+        smem_array +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+        smem_array +
+        ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+        smem_array +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            TilingConfig::TILE_N;
+    //
+    bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter;
+    // Copying A tile from Global to Shared, bypassing L1, using double-buffer
+    if (USE_SEG_1BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(
+          write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy);
+    if (USE_SEG_2BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(
+          write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy);
+    if (USE_SEG_4BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(
+          write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
+    // copying B tile from GlobalMemory to SharedMemory
+    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>(
+        write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
+    cp_async_group_commit();
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
+        read_SPTR, Scales_RPTR,
+        1);  // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each
+             // WARP; read_SPTR is shared among WARPs
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
+        read_SPTR, Scales_RPTR, 2);
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
+        read_SPTR, Scales_RPTR, 3);
+    // Barriers and Synchronizations
+    cp_async_wait_group<PIPELINE_LEVEL_GMEM - 2>();
+    __syncthreads();
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit,
+        read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
+    // Updating global PTRs
+    WARP_StartGPTR_A_1BIT +=
+        SMEM_SIZE_PER_WARP_1BIT / 16;  // 2KB/16=128 (1)/16: int4*+1 = char*+16
+    WARP_StartGPTR_A_2BIT +=
+        SMEM_SIZE_PER_WARP_2BIT / 16;  // 4KB/16=256 (1)/16: int4*+1 = char*+16
+    WARP_StartGPTR_A_4BIT +=
+        SMEM_SIZE_PER_WARP_4BIT / 16;  // 8KB/16=512 (1)/16: int4*+1 = char*+16
+    BTile_GPTR += TilingConfig::TILE_K;
+  }
+  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+  // Store the C fragments to shared memory.
+  float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] =
+      reinterpret_cast<
+          float(*)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4]>(smem);
+  StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);
+  __syncthreads();
+  // Now that shared memory contains all the D tiles, stream them to global
+  // memory.
+  OutputDataType* BlockGlobalPTR = C + BatchID * (M_Global * N_Global) +
+                                   Tile_Start_M + Tile_Start_N * M_Global;
+  for (size_t i = warpId; i < NumColumnToCopy;
+       i += TilingConfig::BLOCK_WARPS)  // i-th column
+#pragma unroll
+    for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M;
+         j += WARP_SIZE)  // j-th row
+    {
+      if constexpr (std::is_same<OutputDataType, half>::value)
+        BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]);
+      else
+        BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j];
+    }
+}
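
The kernel above never stores the FPx weights as whole values in shared memory; each BIT_WIDTH-bit weight is split into up to three packed bit-planes (1-, 2- and 4-bit), selected by the bits of BIT_WIDTH itself, and each plane is copied and dequantized separately. A short Python sketch of that split and the per-warp shared-memory footprint of one 64x64 weight tile (single buffer, no pipelining), matching the SMEM_SIZE_PER_WARP_* constants above:

    # Sketch: 1+2+4 bit-plane split used by QUANT_GEMM_Kernel and the bytes each
    # plane occupies per warp for a 64*64 weight tile (one buffer).
    WEIGHT_PER_WARP = 64 * 64  # WARP_M * WARP_K

    for bit_width in range(2, 8):  # FP2 .. FP7
        planes = [b for b in (1, 2, 4) if bit_width & b]
        bytes_per_warp = {b: WEIGHT_PER_WARP * b // 8 for b in planes}
        assert sum(planes) == bit_width
        print(f"FP{bit_width}: bit-planes {planes}, bytes per warp {bytes_per_warp}")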

+ 70 - 0
kernels/quantization/fp6/kernel_reduction.cuh

@@ -0,0 +1,70 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is copied from
+// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh
+
+/***************************************************************************
+ * Copyright 2023 The FLash-LLM Authors. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ***************************************************************************/
+// Used for the reduction of the result matrix when Split-K is used
+// Reduction_Workspace:     (Split_K, M_Global, N_Global),  column major
+// C:                       (M_Global, N_Global),           column major
+// Each thread deals with 8 output elements; each element is the sum of Split_K
+// partial results
+//      Read  Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8
+//      float_per_thread (256 bit) -> 256 float per warp
+//      Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8
+//      half_per_thread (128 bit) -> 256 half per warp
+// GridSize = (M_Global*N_Global) / 256
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#define REDUCTION_ELEMENT_PER_THREADBLOCK 256
+#define HALF_PER_128BIT 8
+
+__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace,
+                                 size_t M_Global, size_t N_Global,
+                                 int Split_K) {
+  half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
+  float* WARP_GPTR_R =
+      Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
+  half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT;
+  float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT;
+  // Initializing Thread-Local Results
+  float Results[HALF_PER_128BIT];
+#pragma unroll
+  for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f;
+  // Reduction
+  for (int i = 0; i < Split_K; i++) {
+#pragma unroll
+    for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j];
+    THREAD_GPTR_R += M_Global * N_Global;
+  }
+// Writing to global memory
+#pragma unroll
+  for (int i = 0; i < HALF_PER_128BIT; i++)
+    THREAD_GPTR_C[i] = __float2half_rn(Results[i]);
+}
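
Functionally, SplitK_Reduction just sums the Split_K partial result matrices accumulated in the FP32 workspace and writes the FP16 total into C; the layout details (column major, 8 elements per thread, 256 elements per block) only affect how that sum is parallelized. A NumPy sketch of the same value-level reduction, ignoring memory layout (illustrative reference only):

    import numpy as np

    def splitk_reduce_reference(workspace: np.ndarray) -> np.ndarray:
        """workspace: (Split_K, M_Global, N_Global) FP32 partial results.
        Returns the FP16 sum over the Split_K axis, matching what the kernel
        writes into C element by element."""
        assert workspace.ndim == 3
        return workspace.sum(axis=0).astype(np.float16)

    # Example: 4-way Split-K over a 256x128 output.
    ws = np.random.rand(4, 256, 128).astype(np.float32)
    c = splitk_reduce_reference(ws)
    assert c.shape == (256, 128) and c.dtype == np.float16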

+ 82 - 0
kernels/quantization/fp6/ptx_cp.async.cuh

@@ -0,0 +1,82 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is copied from
+// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh
+
+/***************************************************************************
+ * Copyright 2023 The FLash-LLM Authors. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ***************************************************************************/
+// Extended from CUTLASS's source code
+
+#ifndef PTX_CP_ASYNC_CUH
+#define PTX_CP_ASYNC_CUH
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+template <int SizeInBytes>
+__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr,
+                                         bool pred_guard = true) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  static_assert(SizeInBytes == 16, "Size is not supported");
+  unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);
+  asm volatile(
+      "{ \n"
+      "  .reg .pred p;\n"
+      "  setp.ne.b32 p, %0, 0;\n"
+      "  @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred_guard),
+      "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
+#endif
+}
+
+/// Establishes an ordering w.r.t previously issued cp.async instructions. Does
+/// not block.
+__device__ __forceinline__ void cp_async_group_commit() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cp.async.commit_group;\n" ::);
+#endif
+}
+
+/// Blocks until all but <N> previous cp.async.commit_group operations have
+/// committed.
+template <int N>
+__device__ __forceinline__ void cp_async_wait_group() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+#endif
+}
+
+/// Blocks until all previous cp.async.commit_group operations have committed.
+// cp.async.wait_all is equivalent to :
+// cp.async.commit_group;
+// cp.async.wait_group 0;
+__device__ __forceinline__ void cp_async_wait_all() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cp.async.wait_all;\n" ::);
+#endif
+}
+
+#endif

+ 108 - 0
kernels/quantization/fp6/ptx_mma.cuh

@@ -0,0 +1,108 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh
+
+/***************************************************************************
+ * Copyright 2023 The FLash-LLM Authors. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ***************************************************************************/
+#ifndef PTX_MMA_CUH
+#define PTX_MMA_CUH
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <assert.h>
+#include "configs.h"
+
+// MODIFICATION NOTE: to support MSVC
+// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__
+// Reg)[4]
+// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__
+// read_SPTR)
+template <typename TilingConfig>
+__device__ __forceinline__ void B_FromSharedToReg(
+    uint32_t (*__restrict__ Reg)[4],
+    half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    int slice_id) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  #ifdef DEBUG_MODE
+  static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
+                (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0));
+  #endif
+
+  const int warpId = threadIdx.x / WARP_SIZE;
+  int lane_id = threadIdx.x % WARP_SIZE;
+  int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
+  int warp_start_col =
+      TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 *
+      WARP_j;  // each warp may start from reading warp_start_col'th column of
+               // the B tile in shared memory
+  #ifdef DEBUG_MODE
+  assert(warp_start_col == 0);
+  #endif
+
+  int col = (lane_id % 8) + (lane_id / 16) * 8;
+  int row = (lane_id % 16) / 8 * 8;
+  uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(
+      &read_SPTR[warp_start_col + col][slice_id * MMA_16 + row]));
+  if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
+    asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
+                 : "=r"(Reg[0][0]), "=r"(Reg[0][1])
+                 : "r"(smem_local_ptr));
+  } else {
+  #pragma unroll
+    for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) {
+      asm volatile(
+          "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+          : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3])
+          : "r"(smem_local_ptr));
+      smem_local_ptr +=
+          16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
+    }
+  }
+#endif
+}
+
+// MODIFICATION NOTE: to support MSVC, the function signature is changed from
+// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a,
+// uint32_t __restrict__ *b).
+__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c,
+                                                  uint32_t* __restrict__ a,
+                                                  uint32_t* __restrict__ b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+      "{ %0, %1, %2, %3},"
+      "{ %4, %5, %6, %7 },"
+      "{ %8, %9 },"
+      "{ %10, %11, %12, %13 };"
+      : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+        "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
+#endif
+}
+
+#endif  // PTX_MMA_CUH
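
In B_FromSharedToReg above, each lane derives the shared-memory address of one 8x8 sub-tile row for ldmatrix from its lane id: col = (lane_id % 8) + (lane_id / 16) * 8 and row = ((lane_id % 16) / 8) * 8. A quick Python sketch that enumerates this lane-to-offset mapping (illustrative only):

    # Sketch: per-lane (row, col) offsets used for the ldmatrix.x2/.x4 loads above.
    for lane_id in range(32):
        col = (lane_id % 8) + (lane_id // 16) * 8
        row = ((lane_id % 16) // 8) * 8
        print(f"lane {lane_id:2d}: row offset {row}, col offset {col}")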

+ 188 - 0
kernels/quantization/fp6/utils_core.cuh

@@ -0,0 +1,188 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh
+
+#ifndef UTILS_CORE_CUH
+#define UTILS_CORE_CUH
+
+#include <assert.h>
+
+#include "configs.h"
+#include "ptx_mma.cuh"
+#include "utils_parallel_dequant.cuh"
+
+template <int NUM_INT_PER_THREAD>
+__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[],
+                                                               uint32_t* SPTR,
+                                                               int slice_id) {
+  SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE);
+  int lane_id = threadIdx.x % WARP_SIZE;
+#pragma unroll
+  for (int i = 0; i < NUM_INT_PER_THREAD; i++) {
+    Reg[i] = SPTR[lane_id + i * WARP_SIZE];
+  }
+}
+
+// MODIFICATION NOTE: to support MSVC, half __restrict__
+// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
+template <typename TilingConfig, int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void initialize_mma_slice(
+    uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read,
+    uint32_t* __restrict__ A_2BIT_SPTR_read,
+    uint32_t* __restrict__ A_4BIT_SPTR_read,
+    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    uint32_t* RPTR_Scales) {
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+  // Writing registers
+  // Registers to store FPx fragments for a slice (64*16) of the A matrix
+  // => 32 FPx values per thread => up to 7 registers per thread (1+2+4 split);
+  uint32_t a_1bit[1];  // NO double buffer
+  uint32_t a_2bit[2];  // NO double buffer
+  uint32_t a_4bit[4];  // NO double buffer
+  if (USE_SEG_1BIT)
+    CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1BIT_SPTR_read, 0);
+  if (USE_SEG_2BIT)
+    CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2BIT_SPTR_read, 0);
+  if (USE_SEG_4BIT)
+    CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4BIT_SPTR_read, 0);
+  Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
+      a, a_1bit, a_2bit, a_4bit,
+      RPTR_Scales);  // SIMT Dequant: dequantizing FPx to FP16 at register
+                     // level, dequantizing a slice each time
+  B_FromSharedToReg<TilingConfig>(b, B_SPTR_read,
+                                  0);  // Loading B from shared to registers
+}
+
+// MODIFICATION NOTE: to support MSVC, half __restrict__
+// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
+template <typename TilingConfig, int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void core_mma_slice(
+    float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4],
+    uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read,
+    uint32_t* __restrict__ A_2bit_SPTR_read,
+    uint32_t* __restrict__ A_4bit_SPTR_read,
+    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    uint32_t* RPTR_Scales,
+    int slice_id)  // writing slice[slice_id] to registers, k=0 -> slice_id=1
+                   // for prefetching
+{
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+
+#ifdef DEBUG_MODE
+  assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
+         (TilingConfig::WARP_COL_MMA_TENSORS % 2 ==
+          0));  // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded
+                // to a 16*16 MMA block
+#endif
+  const int NumRegSets_a =
+      WARP_ROW_MMA_TENSORS;  // 1 set = 4 registers, containing a 16*16 MMA
+                             // block
+  const int NumRegSets_b =
+      (TilingConfig::WARP_COL_MMA_TENSORS == 1)
+          ? 1
+          : TilingConfig::WARP_COL_MMA_TENSORS /
+                2;  // 1 set = 4 registers, containing a 16*16 MMA block
+  uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] =
+      reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>(
+          c);  // GlobalRegisters for accumulated FP32 results
+
+  // Setting RPTRs for double buffers
+  uint32_t(*a_read)[4] = a;
+  uint32_t(*a_write)[4] = a;
+  uint32_t(*b_read)[4] = b;
+  uint32_t(*b_write)[4] = b;
+  if (slice_id % 2 == 1) {
+    b_write += NumRegSets_b;
+    a_write += NumRegSets_a;
+  } else {
+    b_read += NumRegSets_b;
+    a_read += NumRegSets_a;
+  }
+
+// Reading registers and issuing core tensor core computations (a slice of A and
+// B tile in shared memory)
+#pragma unroll
+  for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
+    if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
+      MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]);
+    } else {
+#pragma unroll
+      for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) {
+        MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i],
+                          b_read[j]);
+        MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4,
+                          a_read[i], b_read[j] + 2);  // c+4; b+2
+      }
+    }
+  }
+  // Writing registers
+  // Registers to store FPx fragments for a slice (64*16) of the A matrix
+  // => 32 FPx values per thread => up to 7 registers per thread (1+2+4 split);
+  uint32_t a_1bit[1];  // NO double buffer
+  uint32_t a_2bit[2];  // NO double buffer
+  uint32_t a_4bit[4];  // NO double buffer
+  if (USE_SEG_1BIT)
+    CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1bit_SPTR_read, slice_id);
+  if (USE_SEG_2BIT)
+    CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2bit_SPTR_read, slice_id);
+  if (USE_SEG_4BIT)
+    CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4bit_SPTR_read, slice_id);
+  Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
+      a_write, a_1bit, a_2bit, a_4bit,
+      RPTR_Scales);  // SIMT Dequant: dequantizing FP6 to FP16 at register
+                     // level, dequantizing a slice each time
+  B_FromSharedToReg<TilingConfig>(
+      b_write, B_SPTR_read, slice_id);  // Loading B from shared to registers
+}
+
+template <typename TilingConfig>
+__device__ __forceinline__ void StoreToSharedMemoryFromRegister(
+    float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],
+    float c[][REG_PER_THREAD_C_TENSOR_16_16]) {
+  const int lane_id = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+  int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);
+#pragma unroll
+  for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
+#pragma unroll
+    for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS;
+         j++) {  // Dealing with one 16*8 Tensor
+      int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS;
+      int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2);
+      int Tensor_row_offset = warp_row_offset + i * MMA_16;
+      int Tensor_col_offset = j * MMA_8;
+#pragma unroll
+      for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) {
+        int row_offset = lane_id / 4;
+        if (r >= 2) row_offset += 8;
+        int col_offset = (lane_id % 4) * 2;
+        if (r % 2 == 1) col_offset += 1;
+        smem_CFrag[Tensor_col_offset + col_offset]
+                  [Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset];
+      }
+    }
+  }
+}
+
+#endif

+ 94 - 0
kernels/quantization/fp6/utils_gmem.cuh

@@ -0,0 +1,94 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
+
+#ifndef UTILS_GMEM_CUH
+#define UTILS_GMEM_CUH
+
+#include <assert.h>
+#include "configs.h"
+#include "ptx_cp.async.cuh"
+
+/*
+ * Copying A1/A2 from global memory to shared memory.
+ * Usually 1024 or 2048 Bytes
+ */
+template <int SMEM_SIZE_IN_BYTES_PER_WARP>
+__device__ __forceinline__ void CopyFromGlobalToShared_A(
+    uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) {
+#ifdef DEBUG_MODE
+  static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0);
+#endif
+  int lane_id = threadIdx.x % WARP_SIZE;
+  half* SPTR_HALF = reinterpret_cast<half*>(SPTR);
+  const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);
+  SPTR_HALF += lane_id * 8;
+  GPTR_HALF += lane_id * 8;
+#pragma unroll
+  for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) {
+    cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard);
+    SPTR_HALF += 256;  // Forward 512 Bytes
+    GPTR_HALF += 256;  // Forward 512 Bytes
+  }
+}
+
+/*
+ * Copying 64 Quant Scales (FP16) from global memory to shared memory.
+ */
+__device__ __forceinline__ void CopyFromGlobalToShared_Scales(
+    half* SPTR_QuantScales, const half* GPTR_A_Scales) {
+  int lane_id = threadIdx.x % WARP_SIZE;
+  int Offset_Shared = lane_id * 2;
+  int Offset_Global = lane_id / 4 + (lane_id % 4) * 16;
+  for (int i = 0; i < 2; i++)
+    SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8];
+}
+
+// MODIFICATION NOTE: to support MSVC, half __restrict__
+// (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
+/*
+ * (1) Copying X  rows * 64 columns of FP16 values, originally in row    major
+ * (2) Copying 64 rows * X  columns of FP16 values, originally in column major
+ * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8
+ * Threads
+ */
+template <int MaxNumOfLinesToCopy, int BLOCK_WARPS>
+__device__ __forceinline__ void CopyFromGlobalToShared(
+    half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    const half* GlobalPTR, const int GlobalStride,
+    const int NumOfLinesLeft,  // To support arbitrary N dimensions.
+    bool Pred = true) {
+  // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
+  const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
+  const int NumOfGroups = NumOfThreads / 8;
+  const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1;
+  // runtime variables
+  const int line_id = threadIdx.x / 8;
+  const int line_offset = (threadIdx.x % 8) * 8;
+  // PTR for source global memory and target shared memory
+  GlobalPTR += line_id * GlobalStride + line_offset;
+  SharedPTR += line_id;
+#pragma unroll
+  for (int i = 0; i < MaxIteration; i++) {
+    bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred;
+    cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
+    //
+    GlobalPTR += NumOfGroups * GlobalStride;
+    SharedPTR += NumOfGroups;
+  }
+}
+
+#endif
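
The B-tile copy above works in groups of 8 threads: each thread issues one 16-byte cp.async (8 halfs), so a group covers one 64-wide FP16 line and a thread block of BLOCK_WARPS * 32 threads copies NumOfGroups lines per iteration. A small Python check of that geometry, assuming BLOCK_WARPS = 4 as noted in the kernel comments:

    # Sketch of the copy geometry in CopyFromGlobalToShared (B tile).
    WARP_SIZE = 32
    BLOCK_WARPS = 4           # assumed TilingConfig::BLOCK_WARPS
    HALFS_PER_THREAD = 8      # one 16-byte cp.async per thread per iteration
    THREADS_PER_LINE = 8      # one group of 8 threads copies one line

    num_threads = BLOCK_WARPS * WARP_SIZE                 # 128
    num_groups = num_threads // THREADS_PER_LINE          # 16 lines per iteration
    halfs_per_line = THREADS_PER_LINE * HALFS_PER_THREAD  # 64 FP16 values per line
    print(num_threads, num_groups, halfs_per_line)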

+ 148 - 0
kernels/quantization/fp6/utils_parallel_dequant.cuh

@@ -0,0 +1,148 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh
+// To support MSVC, all instances of u_int32_t are changed to uint32_t.
+
+#ifndef UTILS_PARALLELDEQUANT_CUH
+#define UTILS_PARALLELDEQUANT_CUH
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+/*
+ * Input:   R1
+ * Outputs: R1, R2
+ * Note:    Simplified Exponent calculation is applied.
+ */
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t* In, uint32_t* Out1,
+                                                   uint32_t* Out2) {
+  //
+  constexpr int RIGHT_SHIFT = 5 - EXPONENT;
+  constexpr int MASK1 = 0x80000000;
+  constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
+  constexpr int MASK3 = MASK2 & 0x7fffffff;
+  constexpr int MASK = MASK3 | MASK3 >> 16;
+  //
+  *Out1 = *In & 0x80008000;
+  *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT;
+  //
+  *In = (*In) << 8;
+  *Out2 = *In & 0x80008000;
+  *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT;
+}
+
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair,
+                                              half Scale) {
+  constexpr int BIAS_OFFSET = (int(1) << (5 - 1)) - (int(1) << (EXPONENT - 1));
+  constexpr int BIAS = int(1) << BIAS_OFFSET;
+  //
+  half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
+  half* FP16_2 = FP16_1 + 1;
+  uint32_t output;
+  half* output_half_ptr = reinterpret_cast<half*>(&output);
+  output_half_ptr[0] =
+      __hmul(__hmul(*FP16_1, __float2half(1.0f * BIAS)), Scale);
+  output_half_ptr[1] =
+      __hmul(__hmul(*FP16_2, __float2half(1.0f * BIAS)), Scale);
+  return output;
+}
+
+// MODIFICATION NOTE: to support MSVC
+// - u_int32_t __restrict__ Reg[][4] is changed to below.
+// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. similarly for
+// read_RPTR_2bit and read_RPTR_4bit
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void Dequant_32FP6_4Way(
+    uint32_t (*__restrict__ Reg)[4], uint32_t* __restrict__ read_RPTR_1bit,
+    uint32_t* __restrict__ read_RPTR_2bit,
+    uint32_t* __restrict__ read_RPTR_4bit, uint32_t* Scales) {
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+  //
+  uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg);
+  uint32_t* Frag_PTR_1bit = read_RPTR_1bit;
+  uint32_t* Frag_PTR_2bit = read_RPTR_2bit;
+  uint32_t* Frag_PTR_4bit = read_RPTR_4bit;
+  half* Scale_RPTR = reinterpret_cast<half*>(Scales);
+// Dequantizing 32 FP6, each Loop dequantizing 4 FP6
+#pragma unroll(8)
+  for (int i = 0; i < 8; i++) {
+    uint32_t Packed_FP6 = 0;
+    uint32_t tmp = 0;
+    // 1bit Frag
+    if (USE_SEG_1BIT) {
+      tmp = (*Frag_PTR_1bit) & 0x80808080;
+      Packed_FP6 |= tmp >> (BIT_WIDTH & 0);
+      if (i % 8 == 7)
+        Frag_PTR_1bit++;
+      else
+        (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1;
+    }
+    // 2bit Frag
+    if (USE_SEG_2BIT) {
+      tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0;
+      Packed_FP6 |= tmp >> (BIT_WIDTH & 1);
+      if (i % 4 == 3)
+        Frag_PTR_2bit++;
+      else
+        (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2;
+    }
+    // 4bit Frag
+    if (USE_SEG_4BIT) {
+      tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0;
+      Packed_FP6 |= tmp >> (BIT_WIDTH & 3);
+      if (i % 2 == 1)
+        Frag_PTR_4bit++;
+      else
+        (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4;
+    }
+    //
+    uint32_t out1, out2;
+    FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
+    //
+    *OutputRegs = MultScale<EXPONENT, MANTISSA>(
+        out1, Scale_RPTR[0]);  // Multiply FP16 scales
+    OutputRegs += 1;
+    *OutputRegs = MultScale<EXPONENT, MANTISSA>(
+        out2, Scale_RPTR[1]);  // Multiply FP16 scales
+    OutputRegs += 1;
+    // Updating offset for FP16 scales for every two iterations
+    if (i % 2 == 1) Scale_RPTR += 2;
+  }
+}
+
+/*
+ *
+ */
+__device__ __forceinline__ void ExtractFromSharedToReg_Scales(
+    uint32_t* Scales, half* WARP_SPTR_Scales) {
+  int lane_id = threadIdx.x % WARP_SIZE;
+  uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales);
+  uint32_t tmpReg = SPTR_uint[lane_id];
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
+    Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4);
+  }
+}
+
+#endif
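
The power-of-two constant applied in MultScale compensates for the exponent-bias difference between the FPx format and FP16 after the shift-based cast in FPx_FP16_Cast_4Way: FP16 uses bias 15 while an FPx format with E exponent bits uses bias 2^(E-1) - 1, so each dequantized value is multiplied by 2^(16 - 2^(E-1)) before the FP16 quantization scale is applied. A small Python sketch of that constant for the exponent widths dispatched in fp6_linear.cu (illustrative only):

    # Sketch: exponent-bias correction used by MultScale for each exponent width.
    FP16_BIAS = 15
    for exponent in range(1, 6):  # E1 .. E5
        fpx_bias = 2 ** (exponent - 1) - 1
        bias_offset = (FP16_BIAS + 1) - 2 ** (exponent - 1)  # == 16 - 2**(E-1)
        assert bias_offset == FP16_BIAS - fpx_bias
        print(f"E{exponent}: multiply by 2**{bias_offset} = {2 ** bias_offset}")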

+ 6 - 0
kernels/quantization/quant_ops.h

@@ -132,6 +132,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                               torch::Tensor& workspace, int64_t size_m,
                               int64_t size_n, int64_t size_k);
 
+torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
+                                          torch::Tensor _in_feats,
+                                          torch::Tensor _weights,
+                                          torch::Tensor _scales,
+                                          int64_t splitK = 1);
+
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

+ 9 - 0
kernels/torch_bindings.cpp

@@ -196,6 +196,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // QuIP# Decompress
   ops.def("quip_decompress", &decompress_e8p_origorder);
   ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
+
+  // fp6_llm
+  ops.def(
+      "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA,"
+      "                            Tensor _in_feats, Tensor _weights,"
+      "                            Tensor _scales, int splitK=1) -> Tensor");
+  ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
+           &fp_eXmY_linear_forward_cuda);
+
 #endif
 
   // Quantized GEMM for GPTQ.
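
With the schema above registered, the op is reachable from Python through torch.ops._C. A hedged usage sketch (the wrapper name quant_llm_linear is hypothetical; _weights and _scales must already be packed into the kernel's ahead-of-time layout, which this sketch does not construct):

    import torch

    def quant_llm_linear(in_feats: torch.Tensor, weights: torch.Tensor,
                         scales: torch.Tensor, exponent: int = 3,
                         mantissa: int = 2, split_k: int = 1) -> torch.Tensor:
        """Call the fp_eXmY CUDA GEMM; EXPONENT=3, MANTISSA=2 selects FP6 E3M2.

        in_feats: FP16 activations; weights/scales: pre-packed FPx bit-planes
        and FP16 quantization scales (layout produced by the offline packer,
        not shown here)."""
        return torch.ops._C.fp_eXmY_linear_forward_cuda(
            exponent, mantissa, in_feats, weights, scales, split_k)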

+ 9 - 0
tests/benchmarks/engine/throughput.py

@@ -66,6 +66,7 @@ def run_aphrodite(
     model: str,
     tokenizer: str,
     quantization: Optional[str],
+    quant_llm_fp_bits: Optional[int],
     tensor_parallel_size: int,
     seed: int,
     n: int,
@@ -90,6 +91,7 @@ def run_aphrodite(
         model=model,
         tokenizer=tokenizer,
         quantization=quantization,
+        quant_llm_fp_bits=quant_llm_fp_bits,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
         trust_remote_code=trust_remote_code,
@@ -226,6 +228,7 @@ def main(args: argparse.Namespace):
     if args.backend == "aphrodite":
         elapsed_time = run_aphrodite(
             requests, args.model, args.tokenizer, args.quantization,
+            args.quant_llm_fp_bits,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype,
@@ -286,6 +289,12 @@ if __name__ == "__main__":
                         '-q',
                         choices=[*QUANTIZATION_METHODS, None],
                         default=None)
+    parser.add_argument('--quant-llm-fp-bits',
+                        type=int,
+                        default=None,
+                        choices=[4, 5, 6, 7],
+                        help="Number of bits for the FP quantization in "
+                        "QuantLLM")
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
                         type=int,