
feat: quant_llm support (#755)

* feat: add fp6 quant kernels

Co-authored-by: intervitens <intervitens@tutanota.com>

* add implementation

* formatting

* fix: add prefix

* codespell

* compile kernel for e3m3

* add fp4_e2m1 support

* add support for fp2, fp3 and configurable exponent

* bump capability to sm_80

* missed one more directive

* clean up

---------

Co-authored-by: intervitens <intervitens@tutanota.com>
AlpinDale committed 5 months ago
commit 73177656ed
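
Usage sketch for the new options (illustrative only; it assumes the top-level `LLM` entrypoint forwards the EngineArgs fields added in this change, and the model name is a placeholder):

# Hypothetical usage; requires a CUDA build (sm_80+) with the fp6 kernels compiled.
from aphrodite import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # placeholder model
    quantization="quant_llm",           # or "fp2".."fp7" to use the default exponent layout
    quant_llm_fp_bits=6,                # total bits per weight (2-7)
    quant_llm_exp_bits=2,               # exponent bits; mantissa = 6 - 2 - 1 = 3
    dtype="float16",                    # the kernels operate on FP16 activations
    enforce_eager=True,                 # the fpX shorthands set these two automatically
)
print(llm.generate(["Hello, world"], SamplingParams(max_tokens=16)))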

+ 1 - 0
CMakeLists.txt

@@ -199,6 +199,7 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
  FetchContent_MakeAvailable(cutlass)

  list(APPEND APHRODITE_EXT_SRC
+    "kernels/quantization/fp6/fp6_linear.cu"
     "kernels/mamba/mamba_ssm/selective_scan_fwd.cu"
     "kernels/mamba/mamba_ssm/selective_scan_fwd.cu"
     "kernels/mamba/causal_conv1d/causal_conv1d.cu"
     "kernels/mamba/causal_conv1d/causal_conv1d.cu"
     "kernels/quantization/aqlm/gemm_kernels.cu"
     "kernels/quantization/aqlm/gemm_kernels.cu"

+ 14 - 0
aphrodite/_custom_ops.py

@@ -465,6 +465,20 @@ def ggml_mul_mat_a8(
    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


+# fp6
+def fp_eXmY_linear_forward_cuda(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: torch.Tensor,
+    _weights: torch.Tensor,
+    _scales: torch.Tensor,
+    splitK: int = 1,
+) -> torch.Tensor:
+    return torch.ops._C.fp_eXmY_linear_forward_cuda(EXPONENT, MANTISSA,
+                                                    _in_feats, _weights,
+                                                    _scales, splitK)
+
+
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
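
A shape-level sketch of calling the new wrapper (assumes a CUDA build with these kernels compiled; per the kernel's checks the output-channel count must be a multiple of 256, the input-channel count a multiple of 64, and the packed weight carries IC // 8 * (1 + E + M) bytes per output channel; the tensors here are dummies, real packed weights come from to_scaled_tc_fpx):

import torch
from aphrodite import _custom_ops as ops

B, IC, OC = 8, 4096, 4096          # IC % 64 == 0, OC % 256 == 0
E, M = 2, 3                        # NBITS = 1 + E + M = 6 (FP6 E2M3)
x = torch.randn(B, IC, dtype=torch.float16, device="cuda")
w_packed = torch.randint(0, 256, (OC, IC // 8 * (1 + E + M)),
                         dtype=torch.uint8, device="cuda")
scales = torch.ones(OC, dtype=torch.float16, device="cuda")

out = ops.fp_eXmY_linear_forward_cuda(E, M, x, w_packed, scales, splitK=1)
print(out.shape)                   # torch.Size([8, 4096])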

+ 75 - 0
aphrodite/common/config.py

@@ -48,6 +48,12 @@ _PP_SUPPORTED_MODELS = [
]

_OPTIMIZED_QUANTS = [
+    "fp2",
+    "fp3",
+    "fp4",
+    "fp5",
+    "fp6",
+    "fp7",
     "fp8",
     "fp8",
     "marlin",
     "marlin",
     "gptq_marlin_24",
     "gptq_marlin_24",
@@ -57,6 +63,7 @@ _OPTIMIZED_QUANTS = [
     "compressed-tensors",
     "compressed-tensors",
     "compressed_tensors",
     "compressed_tensors",
     "experts_int8",
     "experts_int8",
+    "quant_llm",
]


@@ -95,6 +102,8 @@ class ModelConfig:
            weights. If None, we assume the model weights are not quantized.
        deepspeed_fp_bits: Number of bits to use for DeepSpeed FP quantization.
            Supported number of bits are: 4, 6, 8, 12.
+        quant_llm_fp_bits: Number of bits to use for QuantLLM FP quantization.
+            Supported number of bits are: 2, 3, 4, 5, 6, 7.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
@@ -142,6 +151,8 @@ class ModelConfig:
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        deepspeed_fp_bits: Optional[int] = None,
+        quant_llm_fp_bits: Optional[int] = None,
+        quant_llm_exp_bits: Optional[int] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
@@ -168,6 +179,8 @@ class ModelConfig:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.deepspeed_fp_bits = deepspeed_fp_bits
+        self.quant_llm_fp_bits = quant_llm_fp_bits
+        self.quant_llm_exp_bits = quant_llm_exp_bits
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
@@ -316,6 +329,68 @@ class ModelConfig:
                 "quant_method": "deepspeedfp"
                 "quant_method": "deepspeedfp"
             }
             }
 
 
+        VALID_QUANT_LLM_FP_BITS = [2, 3, 4, 5, 6, 7]
+        VALID_QUANT_LLM_EXPONENTS = [1, 2, 3, 4, 5]
+        # The formula is mantissa_bits = fp_bits - exp_bits - 1
+        # The default exp_bits for each fp_bits are as follows:
+        DEFAULT_EXP_BITS = {
+            2: 1,
+            3: 2,
+            4: 2,
+            5: 2,
+            6: 2,
+            7: 3,
+        }
+
+        if self.quantization == "quant_llm":
+            if self.quant_llm_fp_bits is None:
+                raise ValueError(
+                    "quant_llm_fp_bits must be specified when using "
+                    "quant_llm quantization."
+                )
+            if self.quant_llm_fp_bits not in VALID_QUANT_LLM_FP_BITS:
+                raise ValueError(
+                    f"Invalid quant_llm_fp_bits: {self.quant_llm_fp_bits}. "
+                    f"Must be one of {VALID_QUANT_LLM_FP_BITS}."
+                )
+            if self.quant_llm_exp_bits is None:
+                self.quant_llm_exp_bits = DEFAULT_EXP_BITS[
+                    self.quant_llm_fp_bits]
+            else:
+                if self.quant_llm_exp_bits not in VALID_QUANT_LLM_EXPONENTS:
+                    raise ValueError(
+                        f"Invalid exponent bits: {self.quant_llm_exp_bits}. "
+                        f"Must be one of {VALID_QUANT_LLM_EXPONENTS}."
+                    )
+
+            self.hf_config.quantization_config = {
+                "bits": self.quant_llm_fp_bits,
+                "exp_bits": self.quant_llm_exp_bits,
+                "quant_method": "quant_llm"
+            }
+
+        online_quant_methods = ["fp2", "fp3", "fp4", "fp5", "fp6", "fp7"]
+        if self.quantization in online_quant_methods:
+            fp_bits = int(self.quantization[2])
+            if fp_bits not in VALID_QUANT_LLM_FP_BITS:
+                raise ValueError(
+                    f"Invalid quant_llm_fp_bits: {fp_bits}. "
+                    f"Must be one of {VALID_QUANT_LLM_FP_BITS}."
+                )
+            if fp_bits in [2, 3]:
+                logger.warning("FP2 and FP3 quantization methods lead to "
+                               "significant accuracy loss. Use them with "
+                               "caution. Model may be incoherent.")
+            exp_bits = DEFAULT_EXP_BITS[fp_bits]
+            self.hf_config.quantization_config = {
+                "bits": fp_bits,
+                "exp_bits": exp_bits,
+                "quant_method": self.quantization
+            }
+            self.dtype = torch.float16
+            self.enforce_eager = True
+
        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
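
A small sketch of how the fp-bit / exponent-bit combinations above resolve (the `format` field is only for display in this sketch; the actual `quantization_config` written by ModelConfig carries `bits`, `exp_bits`, and `quant_method` as in the code):

# Mirrors DEFAULT_EXP_BITS and the mantissa formula used by ModelConfig.
DEFAULT_EXP_BITS = {2: 1, 3: 2, 4: 2, 5: 2, 6: 2, 7: 3}

def quant_llm_layout(fp_bits, exp_bits=None):
    exp = DEFAULT_EXP_BITS[fp_bits] if exp_bits is None else exp_bits
    mantissa = fp_bits - exp - 1           # mantissa_bits = fp_bits - exp_bits - 1
    return {"bits": fp_bits, "exp_bits": exp, "quant_method": "quant_llm",
            "format": f"FP{fp_bits}_E{exp}M{mantissa}"}

for bits in (2, 3, 4, 5, 6, 7):
    print(quant_llm_layout(bits)["format"])
# FP2_E1M0, FP3_E2M0, FP4_E2M1, FP5_E2M2, FP6_E2M3, FP7_E3M3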

+ 18 - 2
aphrodite/engine/args_tools.py

@@ -94,6 +94,8 @@ class EngineArgs:
    quantization_param_path: Optional[str] = None
    preemption_mode: Optional[str] = None
    deepspeed_fp_bits: Optional[int] = None
+    quant_llm_fp_bits: Optional[int] = None
+    quant_llm_exp_bits: Optional[int] = None
    # Cache Options
    kv_cache_dtype: str = "auto"
    block_size: int = 16
@@ -498,8 +500,20 @@ class EngineArgs:
                            type=int,
                            default=None,
                            help="Category: Quantization Options\n"
-                            "Number of floating bits to use for the deepseed "
-                            "quantization. Supported bits are: 4, 6, 8, 12. ")
+                            "Number of floating bits to use for the deepspeed "
+                            "quantization. Supported bits are: 4, 6, 8, 12.")
+        parser.add_argument("--quant-llm-fp-bits",
+                            type=int,
+                            default=None,
+                            help="Category: Quantization Options\n"
+                            "Number of floating bits to use for the quant_llm "
+                            "quantization. Supported bits are: 2 to 7.")
+        parser.add_argument("--quant-llm-exp-bits",
+                            type=int,
+                            default=None,
+                            help="Category: Quantization Options\n"
+                            "Number of exponent bits to use for the quant_llm "
+                            "quantization. Supported bits are: 1 to 5.")
        # Cache Options
        parser.add_argument(
            '--kv-cache-dtype',
@@ -886,6 +900,8 @@ class EngineArgs:
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
+            quant_llm_fp_bits=self.quant_llm_fp_bits,
+            quant_llm_exp_bits=self.quant_llm_exp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,

+ 9 - 0
aphrodite/quantization/__init__.py

@@ -11,6 +11,7 @@ from aphrodite.quantization.deepspeedfp import DeepSpeedFPConfig
from aphrodite.quantization.eetq import EETQConfig
from aphrodite.quantization.experts_int8 import ExpertsInt8Config
from aphrodite.quantization.fbgemm_fp8 import FBGEMMFp8Config
+from aphrodite.quantization.fp6 import QuantLLMFPConfig
from aphrodite.quantization.fp8 import Fp8Config
from aphrodite.quantization.gguf import GGUFConfig
from aphrodite.quantization.gptq import GPTQConfig
@@ -29,6 +30,7 @@ QUANTIZATION_METHODS = {
     "tpu_int8": Int8TpuConfig,
     "tpu_int8": Int8TpuConfig,
     "eetq": EETQConfig,
     "eetq": EETQConfig,
     "fp8": Fp8Config,
     "fp8": Fp8Config,
+    "quant_llm": QuantLLMFPConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "gguf": GGUFConfig,
     "gguf": GGUFConfig,
     # The order of gptq methods is important for config.py iteration over
     # The order of gptq methods is important for config.py iteration over
@@ -44,6 +46,13 @@ QUANTIZATION_METHODS = {
     "bitsandbytes": BitsAndBytesConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "experts_int8": ExpertsInt8Config,
+    # the quant_llm methods
+    "fp2": QuantLLMFPConfig,
+    "fp3": QuantLLMFPConfig,
+    "fp4": QuantLLMFPConfig,
+    "fp5": QuantLLMFPConfig,
+    "fp6": QuantLLMFPConfig,
+    "fp7": QuantLLMFPConfig,
}



+ 198 - 0
aphrodite/quantization/fp6.py

@@ -0,0 +1,198 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from loguru import logger
+
+from aphrodite import _custom_ops as ops
+from aphrodite.distributed import get_tensor_model_parallel_rank
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.utils.fp6_utils import (_SPLIT_K_MAP,
+                                                    from_scaled_tc_fpx,
+                                                    to_scaled_tc_fpx)
+
+
+class QuantLLMFPConfig(QuantizationConfig):
+    """Config for QuantLLM FP quantizer. It supports fp2, fp3, fp4,
+    fp5, fp6, fp7.
+    
+    Reference: https://arxiv.org/abs/2401.14112
+    
+    Args: 
+        weight_bits: the target quantization bits, should be one of
+            2, 3, 4, 5, 6, 7.
+    """
+
+    def __init__(
+        self,
+        weight_bits: int = 6,
+        exp_bits: int = 2,
+    ) -> None:
+        self.weight_bits = weight_bits
+        self.exponent_bits = exp_bits
+
+        self.mantissa_bits = weight_bits - self.exponent_bits - 1
+
+        self.valid_types = [torch.float16]
+
+        if self.weight_bits not in [2, 3, 4, 5, 6, 7]:
+            raise ValueError(
+                "Currently, only 2-bit to 7-bit quantization is supported "
+                "for QuantLLM FP quantization, but got "
+                f"{self.weight_bits} bits.")
+
+        if get_tensor_model_parallel_rank() == 0:
+            logger.info(f"Loading model in FP{self.weight_bits}_E"
+                        f"{self.exponent_bits}M{self.mantissa_bits} format.")
+
+    def __repr__(self) -> str:
+        return (f"QuantLLMFPConfig(weight_bits={self.weight_bits}, "
+                f"exponent_bits={self.exponent_bits})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "QuantLLMFP"
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QuantLLMFPConfig":
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        exp_bits = cls.get_from_keys(config, ["exp_bits"])
+        return cls(weight_bits=weight_bits, exp_bits=exp_bits)
+
+    def get_linear_method(self) -> "QuantLLMFPLinearMethod":
+        return QuantLLMFPLinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # TODO: figure out the actual minimum compute capability.
+        return 80
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    def get_quant_method(
+            self,
+            layer: torch.nn.Module,
+            prefix: str) -> Optional["QuantLLMFPLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return QuantLLMFPLinearMethod(self)
+        return None
+
+
+class QuantLLMFPLinearMethod(LinearMethodBase):
+    """Linear method for QuantLLMFP quantizer.
+    Args:
+        quant_config: the QuantLLMFP quantization config.
+    """
+
+    def __init__(self, quant_config: QuantLLMFPConfig):
+        self.quant_config = quant_config
+        self.weight = None
+
+    def create_weights(self,
+                       layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int],
+                       input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype,
+                       weight_loader=None,
+                       **extra_weight_attrs):
+        del output_size
+        del input_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight = QuantLLMFPParameter(
+            torch.Size((output_size_per_partition, input_size_per_partition)),
+            params_dtype=params_dtype,
+            quant_config=self.quant_config,
+        )
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+        })
+        layer.register_parameter("weight", weight)
+
+        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
+            # Calls the original weight loader (if any), quantizes the result,
+            # and then loads the quantized parameter.
+            if weight_loader is not None:
+                orig_param_data = param.data
+                param.data = param.quant_llmdequantize()
+                weight_loader(param, loaded_weight, *args, **kwargs)
+                param.data, loaded_weight = orig_param_data, param.data
+            param.quant_llmquantize_(loaded_weight.cuda())
+
+        extra_weight_attrs["weight_loader"] = quant_weight_loader
+        set_weight_attrs(weight, extra_weight_attrs)
+
+    def apply(self,
+              layer,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        weight = layer.weight
+        weights = weight.data
+        scales = weight.scales
+        out_dim, in_dim = weights.shape
+        bsize = x.shape[0]
+        splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(
+            out_dim, 1) if bsize <= 768 else 1
+        if bias is None:
+            return ops.fp_eXmY_linear_forward_cuda(
+                self.quant_config.exponent_bits,
+                self.quant_config.mantissa_bits,
+                x, weights, scales, splitK)
+        else:
+            return ops.fp_eXmY_linear_forward_cuda(
+                self.quant_config.exponent_bits,
+                self.quant_config.mantissa_bits,
+                x, weights, scales, splitK) + bias
+
+
+
+class QuantLLMFPParameter(nn.Parameter):
+    """
+    QuantLLMFP quantized parameter class that implements fp2 through fp7
+    quantization. Weights are stored in quantized form on
+    GPUs, and can be directly applied to float16 activations.
+    """
+
+    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
+                quant_config: QuantLLMFPConfig):
+
+        data = torch.empty(torch.Size(
+            (orig_shape[0],
+             orig_shape[1] * quant_config.weight_bits // 8)),
+            dtype=torch.uint8)
+
+        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
+        self.scales = torch.empty(orig_shape[0],
+                                  dtype=torch.float16)
+        self.quant_config = quant_config
+        self.orig_shape = orig_shape
+        return self
+
+    def quant_llmquantize_(self, tensor: torch.Tensor):
+        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
+        data, scales = to_scaled_tc_fpx(
+            tensor.data, self.quant_config.exponent_bits,
+            self.quant_config.mantissa_bits)
+        self.data.copy_(data)
+        self.scales.copy_(scales)
+
+    def quant_llmdequantize(self, output_dtype=None):
+        output_dtype = output_dtype or torch.get_default_dtype()
+        return from_scaled_tc_fpx(self.data, self.quant_config.exponent_bits,
+                                  self.quant_config.mantissa_bits,
+                                  self.scales).to(output_dtype)
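
The Split-K heuristic in `QuantLLMFPLinearMethod.apply` above reads as the small lookup below (a sketch of the same expression: batch sizes are bucketed in 64-token steps; unknown output dimensions and batches above 768 tokens fall back to splitK = 1):

from aphrodite.quantization.utils.fp6_utils import _SPLIT_K_MAP

def pick_split_k(batch_size: int, out_dim: int) -> int:
    # Mirrors: _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1
    if batch_size > 768:
        return 1
    return _SPLIT_K_MAP[(batch_size - 1) // 64].get(out_dim, 1)

print(pick_split_k(32, 4096))    # bucket [1, 64]    -> 13
print(pick_split_k(200, 8192))   # bucket [193, 256] -> 5
print(pick_split_k(32, 1234))    # unknown out_dim   -> 1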

+ 585 - 0
aphrodite/quantization/utils/fp6_utils.py

@@ -0,0 +1,585 @@
+# ruff: noqa
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
+# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
+#   1. No encodings are reserved for special values (+/-inf, NaN).
+#   2. When downcasting from FP32 to FPx,
+#      - Rounding mode is round to nearest, ties to even.
+#      - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
+#      magnitude (sign is preserved).
+from functools import reduce
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+
+def _n_ones(n: int) -> int:
+    return (1 << n) - 1
+
+
+EBITS_F32, MBITS_F32 = 8, 23
+F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
+_SPLIT_K_MAP = [
+    {  # tokens: [1, 64]
+        3072: 18,
+        4096: 13,
+        5120: 10,
+        6144: 9,
+        8192: 6,
+        10240: 5,
+        14336: 7,
+        28672: 7,
+        57344: 7
+    },
+    {  # tokens: [65:128]
+        3072: 9,
+        4096: 6,
+        5120: 5,
+        6144: 9,
+        8192: 3,
+        10240: 5,
+        14336: 7,
+        28672: 7,
+        57344: 6
+    },
+    {  # tokens: [129:192]
+        3072: 6,
+        4096: 4,
+        5120: 7,
+        6144: 3,
+        8192: 2,
+        10240: 5,
+        14336: 5,
+        28672: 5,
+        57344: 4
+    },
+    {  # tokens: [193:256]
+        3072: 9,
+        4096: 3,
+        5120: 5,
+        6144: 2,
+        8192: 5,
+        10240: 4,
+        14336: 8,
+        28672: 6,
+        57344: 4
+    },
+    {  # tokens: [257:320]
+        3072: 7,
+        4096: 5,
+        5120: 2,
+        6144: 5,
+        8192: 4,
+        10240: 1,
+        14336: 3,
+        28672: 3,
+        57344: 4
+    },
+    {  # tokens: [321:384]
+        3072: 3,
+        4096: 2,
+        5120: 5,
+        6144: 3,
+        8192: 1,
+        10240: 8,
+        14336: 3,
+        28672: 4,
+        57344: 3
+    },
+    {  # tokens: [385:448]
+        3072: 5,
+        4096: 7,
+        5120: 3,
+        6144: 5,
+        8192: 7,
+        10240: 3,
+        14336: 1,
+        28672: 1,
+        57344: 3
+    },
+    {  # tokens: [449:512]
+        3072: 2,
+        4096: 5,
+        5120: 4,
+        6144: 1,
+        8192: 5,
+        10240: 2,
+        14336: 6,
+        28672: 4,
+        57344: 1
+    },
+    {  # tokens: [513:576]
+        3072: 2,
+        4096: 3,
+        5120: 1,
+        6144: 1,
+        8192: 3,
+        10240: 3,
+        14336: 3,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [577:640]
+        3072: 5,
+        4096: 4,
+        5120: 1,
+        6144: 4,
+        8192: 2,
+        10240: 1,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [641:704]
+        3072: 3,
+        4096: 1,
+        5120: 2,
+        6144: 2,
+        8192: 1,
+        10240: 2,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    },
+    {  # tokens: [705:768]
+        3072: 3,
+        4096: 1,
+        5120: 3,
+        6144: 2,
+        8192: 1,
+        10240: 1,
+        14336: 1,
+        28672: 1,
+        57344: 1
+    }
+]
+
+
+def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Note: there are no special values (NaN, inf) support in this code. Values
+    outside the representable range of FPx after rounding are clamped to the
+    maximum FPx magnitude (sign is preserved).
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+    """
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    max_int = _n_ones(ebits + mbits)
+    sign_mask = 1 << (ebits + mbits)
+
+    # TODO document this better
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    x = x ^ sign
+
+    # TODO: can the branch floating point comparisons below be done without
+    # converting to float? probably but need to verify
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    #
+    # branch 1: saturate to max val - handled later in the code which combines
+    #   the branches
+    #
+
+    #
+    # branch 2: to conversion to denormal as well as rounding up to normal
+    #
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    #
+    # branch 3: stay in normal range, adjust the exponent and round
+    #
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    #
+    # combine the branches
+    #
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+# TODO(future): check if LUT for everything is faster than bit shifting,
+# especially for fp4 (only 2^4=16 unique values).
+def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
+    """
+    assert x.dtype == torch.uint8
+    assert 1 + ebits + mbits <= 8
+
+    sign_mask = 1 << (ebits + mbits)
+    exp_bias = _n_ones(ebits - 1)
+    mantissa_mask = _n_ones(mbits)
+
+    # save the sign
+    sign_lp = x & sign_mask
+
+    # set everything to positive, will add sign back at the end
+    x_pos = x ^ sign_lp
+
+    #
+    # 1. Calculate zero mask
+    #
+    zero_mask = x_pos == 0
+
+    #
+    # 2. Calculate the denormal path mask
+    #
+    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
+
+    #
+    # 3. Calculate the normal path
+    #
+
+    # calculate the new exponent and shift it to bits 2:9 of the result
+    exp_biased_lp = x_pos >> mbits
+    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
+    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
+
+    # shift the mantissa to bits 10:32 of the result
+    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
+    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
+    result = exp_biased_f32 | mantissa_f32
+
+    #
+    # 4. Add the zero and denormal casts to the already casted normal path
+    #
+    result[zero_mask] = 0
+
+    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
+
+    # fast path.
+    # without this, performance for FP4_E2M1 is slower by 2x
+    if mbits == 1:
+        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
+
+    else:
+        # iterate over all possible values of mantissa
+        # i=0, j=1
+        # i=1, j=10,11
+        # i=2, j=100,101,110,111
+        # and so on
+        for i in range(mbits):
+            for mantissa_cmp in range(1 << i, 1 << (i+1)):
+                # left shift mantissa until it overflows (create an implicit 1)
+                # subtract exponent by the same amount
+                left_shift = mbits - i
+                mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
+                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
+
+                # we can update this in-place since the values won't overlap
+                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
+                # thus we use + instead of | here
+                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32
+
+        result = torch.where(denormal_mask, mantissa_lp_int32, result)
+
+    # add sign back
+    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
+    result = result | sign_f32
+
+    return result.view(torch.float)
+
+
+def quant_llm_linear(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: Tensor,
+    _weights: Tensor,
+    _scales: Tensor,
+    splitK: int = 1,
+) -> Tensor:
+    """
+    Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
+    Arguments
+        EXPONENT: number of exponent bits
+        MANTISSA: number of mantissa bits
+        _in_feats: input activations in FP16
+        _weights: packed FPx weights
+        _scales: scale
+        splitK: split K
+    Returns
+        output of linear layer
+    """
+    return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK)
+
+
+_ONES_TABLE = [_n_ones(i) for i in range(8)]
+
+
+def _pack(x: Tensor, n_bits: int) -> Tensor:
+    return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)])
+
+
+def _unpack(x: Tensor, n_bits: int) -> Tensor:
+    return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2)
+
+
+# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116
+def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor:
+    # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8
+    # thus, we need to reverse byte order within a uint32 word.
+    x = x.reshape(-1, 4).flip(1)
+
+    x = _unpack(x, n_bits)
+    x = x.view(-1, 4 * (8 // n_bits))
+
+    if not undo:
+        bit_order = {
+            1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31,
+                0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30],
+            2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14],
+            4: [1, 5, 3, 7, 0, 4, 2, 6],
+        }[n_bits]
+
+    else:
+        # this is inverse of the above, obtained by running
+        # [v.index(i) for i in range(len(v))]
+        bit_order = {
+            1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11,
+                20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15],
+            2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7],
+            4: [4, 0, 6, 2, 5, 1, 7, 3],
+        }[n_bits]
+
+    x = x[:, bit_order]
+    x = _pack(x, n_bits)
+
+    # reverse byte order within a uint32 word again.
+    x = x.reshape(-1, 4).flip(1)
+    return x.flatten()
+
+
+# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing
+# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h
+def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    assert tensor.ndim == 2, tensor.dtype == torch.uint8
+    M, N = tensor.shape
+    assert (M % 64 == 0) and (N % 64 == 0)
+
+    # Pass 1 from original code
+    tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8)
+    tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6)
+    tensor = tensor.reshape(-1, 32, 2)
+    tensor = tensor.permute(1, 0, 2)
+    tensor = tensor.flatten()
+
+    used_bits = 0
+    fragments = []
+
+    for y in [1, 2, 4]:
+        if nbits & y:
+            mask = (1 << y) - 1
+            tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask
+            tensor_ybit = _pack(tensor_ybit, y)
+
+            tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2)
+            tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y)
+            fragments.append(tensor_ybit)
+            used_bits += y
+
+    return torch.cat(fragments, dim=0).view(M, -1)
+
+
+# more optimized version of _pack_tc_fpx() for FP6 by merging ops
+def _pack_tc_fp6(tensor: Tensor) -> Tensor:
+    assert tensor.ndim == 2, tensor.dtype == torch.uint8
+    M, N = tensor.shape
+    assert (M % 64 == 0) and (N % 64 == 0)
+
+    tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8)
+    tensor = tensor.flip(3)
+
+    tensor_2bit = (tensor >> 4) & 0b11
+    tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6)
+    tensor_2bit = _pack(tensor_2bit.flatten(), 2)
+
+    tensor_4bit = tensor & 0b1111
+    tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6)
+    tensor_4bit = _pack(tensor_4bit.flatten(), 4)
+
+    return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1)
+
+
+# currently only optimize for TC-FP6 packing
+def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    if nbits == 6:
+        return _pack_tc_fp6(tensor)
+    return _pack_tc_fpx(tensor, nbits)
+
+
+def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]:
+    # _n_ones() is not compatible with torch.compile() due to << operator
+    # https://github.com/pytorch/pytorch/issues/119152
+    # exp_bias = _n_ones(ebits - 1)
+    # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
+
+    # workaround: global lookup table
+    exp_bias = _ONES_TABLE[ebits - 1]
+    max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits))
+
+    tensor = tensor.float()
+    scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
+    tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
+    tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits)
+    return tensor_tc_fpx, scale.half()
+
+
+# inverse of _pack_tc_fpx()
+def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    assert tensor.ndim == 2 and tensor.dtype == torch.uint8
+    M = tensor.shape[0]
+    size = tensor.numel()
+    tensor = tensor.flatten()
+    offset = 0
+    used_bits = 0
+
+    tensor_fpx = None
+
+    for y in [1, 2, 4]:
+        if nbits & y:
+            size_ybit = size // nbits * y
+            tensor_ybit = tensor[offset : offset + size_ybit]
+            offset += size_ybit
+
+            tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True)            # undo Pass 3
+            tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2)  # undo Pass 2
+
+            tensor_ybit = _unpack(tensor_ybit.flatten(), y)
+            tensor_ybit = tensor_ybit << (nbits - used_bits - y)
+            used_bits += y
+
+            if tensor_fpx is None:
+                tensor_fpx = tensor_ybit
+            else:
+                tensor_fpx |= tensor_ybit
+
+    # undo Pass 1
+    tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2)
+    tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8)
+    tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6)
+    tensor_fpx = tensor_fpx.reshape(M, -1)
+    return tensor_fpx
+
+
+# more optimized version of _unpack_tc_fpx() for FP6 by merging ops
+# inverse of _unpack_tc_fp6()
+def _unpack_tc_fp6(tensor: Tensor) -> Tensor:
+    assert tensor.ndim == 2 and tensor.dtype == torch.uint8
+    M = tensor.shape[0]
+    N = tensor.shape[1] // 3 * 4
+    assert (M % 64 == 0) and (N % 64 == 0)
+    size_2bit = M * N // 4
+    size_4bit = M * N // 2
+    tensor = tensor.view(-1)
+    assert tensor.numel() == size_2bit + size_4bit
+
+    tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit])
+
+    tensor_2bit = _unpack(tensor_2bit, 2)
+    tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2)
+    tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4)
+
+    tensor_4bit = _unpack(tensor_4bit, 4)
+    tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2)
+    tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5)
+
+    tensor_fp6 = (tensor_2bit << 4) | tensor_4bit
+    tensor_fp6 = tensor_fp6.flip(3).reshape(M, N)
+    return tensor_fp6
+
+
+def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor:
+    if nbits == 6:
+        return _unpack_tc_fp6(tensor)
+    return _unpack_tc_fpx(tensor, nbits)
+
+
+def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor:
+    fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits)
+    tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits)
+    if scale is not None:
+        tensor = tensor * scale.float().view(-1, 1)
+    return tensor
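
Two quick sanity checks for the helpers above (a sketch; these are plain PyTorch ops and should run on CPU, and the tensor-core pre-packing requires both matrix dimensions to be multiples of 64). The first reproduces the `max_normal` formula used by `to_scaled_tc_fpx`, the second round-trips a 64x64 matrix through quantize/dequantize:

import torch
from aphrodite.quantization.utils.fp6_utils import (
    _n_ones, from_scaled_tc_fpx, to_scaled_tc_fpx)

def max_normal(ebits: int, mbits: int) -> float:
    # Largest representable FPx magnitude: all exponent and mantissa bits set.
    exp_bias = _n_ones(ebits - 1)
    return 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / 2 ** mbits)

print(max_normal(2, 3))  # FP6 E2M3 -> 7.5
print(max_normal(3, 2))  # FP6 E3M2 -> 28.0

w = torch.randn(64, 64)
packed, scales = to_scaled_tc_fpx(w, ebits=2, mbits=3)   # uint8 [64, 48], fp16 [64]
w_hat = from_scaled_tc_fpx(packed, ebits=2, mbits=3, scale=scales)
print((w - w_hat).abs().max())   # small, relative to each row's max magnitude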

+ 73 - 0
kernels/quantization/fp6/configs.h

@@ -0,0 +1,73 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is copied from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h
+
+#ifndef CONFIGS_H
+#define CONFIGS_H
+
+// #define DEBUG_MODE
+#define PIPELINE_LEVEL_GMEM 2
+#define PIPELINE_LEVEL_SMEM 2  // only support 2
+
+/************************ Hardware Parameters ************************/
+#define WARP_SIZE 32
+#define REG_BIT_WIDTH 32
+// mma: M=16 K=16 N=8
+#define MMA_8 8
+#define MMA_16 16
+// for memory access
+#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128  // LDS.128, cp_async.128, ...
+#define BIT_WIDTH_PER_HALF 16                // Half precision: FP16
+
+/******************** Register Allocation For GEMM ********************/
+#define REG_PER_THREAD_C_TENSOR_16_16 8  // 8 for FP32 Accumulation
+/********************** Memory Padding Parameters **********************/
+// Eliminating bank-conflict
+#define PADDING_BYTES_16 16  // Padding 16 bytes each column
+#define PADDING_SHARED_MEM_FOR_B_8 \
+  8  // Padding 8 half  each column, during CopyFromGlobalToShared() for B
+#define PADDING_SHARED_MEM_FOR_C_4 \
+  4  // Padding 4 float each column, during StoreToSharedMemoryFromRegister()
+     // for C
+/************************* WARP Tiling part-1 *************************/
+#define WARP_ROW_MMA_TENSORS 4
+#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16)  // 64
+#define WARP_K_MMA_TENSORS 4
+#define WARP_K (WARP_K_MMA_TENSORS * MMA_16)  // 64
+template <int BLOCK_ROW_WARPS_, int BLOCK_COL_WARPS_, int WARP_COL_MMA_TENSORS_>
+struct TilingConfig {
+  // Depending on "n" dimension of the GEMM
+  static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_;
+  static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_;
+  static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_;
+  /************************* WARP Tiling part-2 *************************/
+  static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8;
+  /*************************Thread Block Tiling *************************/
+  static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS;
+  static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS;
+  static constexpr int TILE_K = WARP_K;
+  /********************** #Thread per Thread Block **********************/
+  static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS;
+  static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE;
+  /******************************* Others *******************************/
+  static constexpr int SMEM_SIZE_B_TILE =
+      TILE_N * (TILE_K + PADDING_BYTES_16) * 2 *
+      PIPELINE_LEVEL_GMEM;  // sizeof(half)=2, doubleBuffer=2
+  static constexpr int SMEM_SIZE_C_TILE =
+      TILE_N * (TILE_M + PADDING_BYTES_16) * 4;  // sizeof(float)=4
+};
+
+#endif  // CONFIGS_H
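
For orientation, the tile shapes produced by the `TilingConfig` instantiations used by the launcher in fp6_linear.cu work out as below (a Python sketch of the constexpr arithmetic above, using the macro values WARP_M = WARP_K = 64, MMA_8 = 8, WARP_SIZE = 32):

WARP_M = WARP_K = 64
MMA_8, WARP_SIZE = 8, 32

def tiling(block_row_warps, block_col_warps, warp_col_mma_tensors):
    tile_m = WARP_M * block_row_warps
    tile_n = MMA_8 * warp_col_mma_tensors * block_col_warps
    tile_k = WARP_K
    threads = block_row_warps * block_col_warps * WARP_SIZE
    return tile_m, tile_n, tile_k, threads

for cfg in [(4, 1, 1), (4, 1, 2), (4, 1, 4), (4, 1, 8)]:
    print(cfg, tiling(*cfg))
# TilingConfig<4, 1, 8> gives a 256 x 64 x 64 tile with 128 threads per block,
# matching the M_Global % 256 == 0 and K_Global % 64 == 0 asserts in the launcher.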

+ 332 - 0
kernels/quantization/fp6/fp6_linear.cu

@@ -0,0 +1,332 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is adapted from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu
+
+#include "kernel_matmul.cuh"
+#include "kernel_reduction.cuh"
+
+#include <stdio.h>
+#include <assert.h>
+
+namespace aphrodite {
+
+template <typename TilingConfig, typename OutputDataType, int EXPONENT,
+          int MANTISSA>
+static void Kernel_Ex(cudaStream_t stream, const uint4* Weight,
+                      const half* Scales, const half* B, OutputDataType* C,
+                      const size_t M_Global, const size_t N_Global,
+                      const size_t K_Global, int Split_K) {
+#ifdef DEBUG_MODE
+  printf("\n");
+  printf("Launcher.cu->Kernel_Ex():\n");
+  printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global,
+         Split_K);
+  printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M,
+         TilingConfig::TILE_K, TilingConfig::TILE_N);
+#endif
+  static size_t SHMEM_SZ =
+      max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_PER_TB_A_TILE,
+          TilingConfig::SMEM_SIZE_C_TILE);
+  cudaFuncSetAttribute(
+      QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>,
+      cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ);
+  size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
+  size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
+  dim3 GridDim(dimN, dimM, 1);
+  dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
+//
+#ifdef DEBUG_MODE
+  printf(
+      "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, "
+      "BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n",
+      GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z,
+      SHMEM_SZ);
+  printf("\n");
+#endif
+  QUANT_GEMM_Kernel<TilingConfig, OutputDataType, EXPONENT, MANTISSA>
+      <<<GridDim, BlockDim, SHMEM_SZ, stream>>>(Weight, Scales, B, C, M_Global,
+                                                N_Global, K_Global, Split_K);
+}
+
+template <int EXPONENT, int MANTISSA>
+cudaError_t fpx_linear_kernel(
+    cudaStream_t stream, const uint4* Weight, const half* Scales, const half* B,
+    half* C, const size_t M_Global, const size_t N_Global,
+    const size_t K_Global,
+    float* Reduction_Workspace,  // Reduction_Workspace_Size = Split_K *
+                                 // M_Global * N_Global * sizeof(fp32)
+    int Split_K) {
+  assert(M_Global % 256 == 0);
+  assert(K_Global % 64 == 0);
+  assert(N_Global > 0);
+
+  // Work around to support more N shapes:
+  size_t N_PowerOf2;
+  if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8;
+  if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16;
+  if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32;
+  if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64;
+  if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128;
+  if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128;
+
+  if (Split_K == 1) {
+    switch (N_PowerOf2) {
+      case 8:
+        Kernel_Ex<TilingConfig<4, 1, 1>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 16:
+        Kernel_Ex<TilingConfig<4, 1, 2>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 32:
+        Kernel_Ex<TilingConfig<4, 1, 4>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 64:
+        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      case 128:
+        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+      default:
+        if (N_PowerOf2 % 128 != 0) {
+          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n", N_PowerOf2);
+          return cudaErrorUnknown;
+        }
+        Kernel_Ex<TilingConfig<4, 1, 8>, half, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, C, M_Global, N_Global, K_Global,
+            Split_K);
+        break;
+    }
+  } else {
+    switch (N_PowerOf2) {
+      case 8:
+        Kernel_Ex<TilingConfig<4, 1, 1>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 16:
+        Kernel_Ex<TilingConfig<4, 1, 2>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 32:
+        Kernel_Ex<TilingConfig<4, 1, 4>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 64:
+        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      case 128:
+        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+      default:
+        if (N_PowerOf2 % 128 != 0) {
+          printf("FP6LLM_API Error: Unsupported N dimension %zu!\n", N_PowerOf2);
+          return cudaErrorUnknown;
+        }
+        Kernel_Ex<TilingConfig<4, 1, 8>, float, EXPONENT, MANTISSA>(
+            stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global,
+            K_Global, Split_K);
+        break;
+    }
+    // Reduction for SplitK
+    dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1,
+                 1);
+    dim3 BlockDim(WARP_SIZE, 1, 1);
+    SplitK_Reduction<<<GridDim, BlockDim, 0, stream>>>(
+        C, Reduction_Workspace, M_Global, N_Global, Split_K);
+  }
+  return cudaGetLastError();
+}
+}  // namespace aphrodite
+
+#include <torch/all.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/library.h>
+
+// MODIFICATION NOTE: dtype of _weights is changed to uint8
+/*
+Computes FPx-FP16 GEMM (PyTorch interface).
+
+[Mathematical Formula]
+Standard definition of linear layer:  Out = In * trans(W), where In, Out, and
+W are stored in row-major. After equivalent transformation:
+trans(Out) = W * trans(In). Note that we do not perform "transpose" during
+runtime; we instead interpret In/Out as column-major matrices when calling
+our CUDA kernel.
+
+[Inputs]
+  _in_feats:  tensor of shape [B, IC];                // half
+  _weights:   int tensor of shape [OC, IC // 8 * x];  // x UINT8 words contain
+                                                      // 8 FPx weights each.
+  _scales:    tensor of shape [OC];                   // half
+  splitK:     splitting the MatMul problem along the K dimension for higher
+              GPU utilization, default 1.
+[Outputs]
+  _out_feats: tensor of shape [B, OC];                // half
+*/
+torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
+                                          torch::Tensor _in_feats,
+                                          torch::Tensor _weights,
+                                          torch::Tensor _scales,
+                                          int64_t splitK = 1) {
+  const int64_t NBITS = 1 + EXPONENT + MANTISSA;
+  int num_in_feats = _in_feats.size(0);
+  int num_in_channels = _in_feats.size(1);
+  int num_out_channels = _weights.size(0);
+  TORCH_CHECK(num_in_channels % 64 == 0,
+              "Expected in_features to be a multiple of 64, but received ",
+              num_in_channels);
+  TORCH_CHECK((num_in_channels / 8 * NBITS) ==
+              _weights.size(1));  // Making sure the K dimension is matched.
+  //
+  int M = num_out_channels;
+  int K = num_in_channels;
+  int N = num_in_feats;
+  // Input Tensors
+  auto weight = reinterpret_cast<const uint4*>(
+      _weights.data_ptr<uint8_t>());  // weights is [OC, IC] but in FP6.
+  auto in_feats = reinterpret_cast<const half*>(_in_feats.data_ptr<at::Half>());
+  auto scales = reinterpret_cast<const half*>(_scales.data_ptr<at::Half>());
+  // Output Tensors
+  auto options = torch::TensorOptions()
+                     .dtype(_in_feats.dtype())
+                     .device(_in_feats.device());
+  at::Tensor _out_feats =
+      torch::empty({num_in_feats, num_out_channels}, options);
+  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+
+  options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device());
+  at::Tensor _workspace =
+      torch::empty({splitK, num_in_feats, num_out_channels}, options);
+  auto Reduction_Workspace = reinterpret_cast<float*>(
+      _workspace.data_ptr<float>());  // Reduction_Workspace_Size = Split_K *
+                                      // M_Global * N_Global * sizeof(fp32)
+
+  // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default
+  // stream (0) this fixes problem with CUDA graphs when used with
+  // torch.compile()
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  /*
+   The heuristic is weight_bit - exponent_bit - 1 = mantissa_bit
+   */
+
+  // FP2
+  if (EXPONENT == 1 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<1, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  // FP3
+  else if (EXPONENT == 1 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<1, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  // FP4
+  else if (EXPONENT == 1 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<1, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  // FP5
+  else if (EXPONENT == 1 && MANTISSA == 3)
+    aphrodite::fpx_linear_kernel<1, 3>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 4 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<4, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  // FP6
+  else if (EXPONENT == 1 && MANTISSA == 4)
+    aphrodite::fpx_linear_kernel<1, 4>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 3)
+    aphrodite::fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 4 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<4, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 5 && MANTISSA == 0)
+    aphrodite::fpx_linear_kernel<5, 0>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  // FP7
+  else if (EXPONENT == 1 && MANTISSA == 5)
+    aphrodite::fpx_linear_kernel<1, 5>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 2 && MANTISSA == 4)
+    aphrodite::fpx_linear_kernel<2, 4>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 3 && MANTISSA == 3)
+    aphrodite::fpx_linear_kernel<3, 3>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 4 && MANTISSA == 2)
+    aphrodite::fpx_linear_kernel<4, 2>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+  else if (EXPONENT == 5 && MANTISSA == 1)
+    aphrodite::fpx_linear_kernel<5, 1>(stream, weight, scales, in_feats,
+                                       out_feats, M, N, K, Reduction_Workspace,
+                                       splitK);
+
+  else
+    TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA,
+                " is not supported.");
+
+  return _out_feats;
+}
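
Note on the dispatch above: every supported format satisfies NBITS = 1 + EXPONENT + MANTISSA (one sign bit), which is what the TORCH_CHECK message reports for unsupported pairs. A minimal Python sketch of the supported pairs (illustrative only; the names below are hypothetical and not part of the kernel API):

# Hypothetical reference table of the (EXPONENT, MANTISSA) pairs instantiated
# by the dispatch in fp6_linear.cu; NBITS counts the implicit sign bit.
SUPPORTED_EXMY = [
    (1, 0),                                  # FP2
    (1, 1), (2, 0),                          # FP3
    (1, 2), (2, 1), (3, 0),                  # FP4
    (1, 3), (2, 2), (3, 1), (4, 0),          # FP5
    (1, 4), (2, 3), (3, 2), (4, 1), (5, 0),  # FP6
    (1, 5), (2, 4), (3, 3), (4, 2), (5, 1),  # FP7
]

def nbits(exponent: int, mantissa: int) -> int:
    # One sign bit plus the exponent and mantissa fields.
    return 1 + exponent + mantissa

assert all(2 <= nbits(e, m) <= 7 for e, m in SUPPORTED_EXMY)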

+ 354 - 0
kernels/quantization/fp6/kernel_matmul.cuh

@@ -0,0 +1,354 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh
+
+#include "configs.h"
+#include "utils_gmem.cuh"
+#include "utils_core.cuh"
+
+/************************** Bitwidth of Weight Segments
+ * ************************/
+#define BIT_WIDTH_1 1
+#define BIT_WIDTH_2 2
+#define BIT_WIDTH_4 4
+/*************************** 64*64 Weights of Weight Matrix
+ * *********************/
+#define WEIGHT_PER_WARP (WARP_M * WARP_K)  // 64*64 = 4096
+#define SMEM_SIZE_PER_WARP_1BIT    \
+  (WEIGHT_PER_WARP * BIT_WIDTH_1 / \
+   8)  // 512 Bytes,  doubleBuffer not taken into consideration
+#define SMEM_SIZE_PER_WARP_2BIT    \
+  (WEIGHT_PER_WARP * BIT_WIDTH_2 / \
+   8)  // 1024 Bytes, doubleBuffer not taken into consideration
+#define SMEM_SIZE_PER_WARP_4BIT    \
+  (WEIGHT_PER_WARP * BIT_WIDTH_4 / \
+   8)  // 2048 Bytes, doubleBuffer not taken into consideration
+#define SMEM_SIZE_PER_TB_1BIT                            \
+  (SMEM_SIZE_PER_WARP_1BIT * TilingConfig::BLOCK_WARPS * \
+   PIPELINE_LEVEL_GMEM)  // #WARP=4; Triple-Buffer for 3-level pipeline for A
+                         // = 6 KB;  double buffer for 2-level pipeline A= 4
+                         // KB.
+#define SMEM_SIZE_PER_TB_2BIT                            \
+  (SMEM_SIZE_PER_WARP_2BIT * TilingConfig::BLOCK_WARPS * \
+   PIPELINE_LEVEL_GMEM)  // #WARP=4; Triple-Buffer for 3-level pipeline for A
+                         // = 12 KB; double buffer for 2-level pipeline A= 8
+                         // KB.
+#define SMEM_SIZE_PER_TB_4BIT                            \
+  (SMEM_SIZE_PER_WARP_4BIT * TilingConfig::BLOCK_WARPS * \
+   PIPELINE_LEVEL_GMEM)  // #WARP=4; Triple-Buffer for 3-level pipeline for A
+                         // = 24 KB; double buffer for 2-level pipeline A= 16
+                         // KB.
+#define SMEM_SIZE_PER_TB_A_TILE                    \
+  (SMEM_SIZE_PER_TB_1BIT + SMEM_SIZE_PER_TB_2BIT + \
+   SMEM_SIZE_PER_TB_4BIT)  // used in fp6_linear.cu, Kernel_Ex().
+/******************** Global Memory Layout For QUANTIZED DATA
+ * *******************/
+#define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP * BIT_WIDTH_1 / 128)  // 32
+#define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP * BIT_WIDTH_2 / 128)  // 64
+#define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP * BIT_WIDTH_4 / 128)  // 128
+
+/*
+ * C = A*B
+ * A: row major with ahead-of-time layout transformation, FP6
+ * B: col major, FP16
+ * C: col major, FP16
+ */
+template <typename TilingConfig, typename OutputDataType, int EXPONENT,
+          int MANTISSA>
+__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
+                                  const half* B, OutputDataType* C,
+                                  const size_t M_Global, const size_t N_Global,
+                                  const size_t K_Global, int Split_K) {
+#ifdef DEBUG_MODE
+  assert(K_Global % TilingConfig::TILE_K == 0);
+  assert(M_Global % TilingConfig::TILE_M == 0);
+  assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M));
+#endif
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+  const uint4* Weight_1bit = Weight;
+  const uint4* Weight_2bit =
+      Weight_1bit +
+      (USE_SEG_1BIT ? M_Global * K_Global * BIT_WIDTH_1 / 128 : 0);
+  const uint4* Weight_4bit =
+      Weight_2bit +
+      (USE_SEG_2BIT ? M_Global * K_Global * BIT_WIDTH_2 / 128 : 0);
+  // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned
+  extern __shared__ __align__(128) half smem[];
+  half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+      reinterpret_cast<half(*)[WARP_K + PADDING_SHARED_MEM_FOR_B_8]>(
+          smem + SMEM_SIZE_PER_TB_A_TILE /
+                     2);  // Dynamic shared memory for FP16 B tiles
+  __shared__ half
+      QuantScales[64 *
+                  TilingConfig::BLOCK_WARPS];  // static shared memory for
+                                               // quantization scales, 64 row
+                                               // per warp * 4 warps = 512 Bytes
+  // Thread Block Mapping, considering SplitK
+  const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M);
+  const size_t x =
+      blockIdx.x;  // Output Block ID: (BlockID_Row = y; BlockID_Col = x )
+  const size_t y =
+      blockIdx.y %
+      (M_Global / TilingConfig::TILE_M);  // Output Block ID: (BlockID_Row = y;
+                                          // BlockID_Col = x )
+  const size_t Tile_Start_M = y * TilingConfig::TILE_M;
+  const size_t Tile_Start_N = x * TilingConfig::TILE_N;
+  const size_t NumColumnToCopy =
+      (N_Global - Tile_Start_N) < TilingConfig::TILE_N
+          ? (N_Global - Tile_Start_N)
+          : TilingConfig::TILE_N;
+  const size_t NumBlock_K = K_Global / TilingConfig::TILE_K;
+  const size_t AverageNumBlock_K = NumBlock_K / Split_K;
+  const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K;
+  size_t NumIter = AverageNumBlock_K;
+  size_t StartBlockID_K = AverageNumBlock_K * BatchID;
+  if (BatchID < ExtraNumBlock_K) {
+    NumIter++;
+    StartBlockID_K += BatchID;
+  } else
+    StartBlockID_K += ExtraNumBlock_K;
+  // Warp ID.
+  const int warpId = threadIdx.x / WARP_SIZE;
+  int WARP_i = warpId / TilingConfig::BLOCK_COL_WARPS;  // WARP_i: row number;
+                                                        // WARP_j: column number
+  // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
+  //  Global Memory Address for Matrix A (Weight)
+  //  /////////////////////////////////////////////////////////////////////////
+  //  StartPTR for each ThreadBlock(TB)
+  const uint4* TB_StartGPTR_A_1BIT =
+      Weight_1bit +
+      (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
+  const uint4* TB_StartGPTR_A_2BIT =
+      Weight_2bit +
+      (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
+  const uint4* TB_StartGPTR_A_4BIT =
+      Weight_4bit +
+      (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
+  // StartPTR for each WARP.
+  const uint4* WARP_StartGPTR_A_1BIT =
+      TB_StartGPTR_A_1BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_1BIT;
+  const uint4* WARP_StartGPTR_A_2BIT =
+      TB_StartGPTR_A_2BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_2BIT;
+  const uint4* WARP_StartGPTR_A_4BIT =
+      TB_StartGPTR_A_4BIT + WARP_i * NumBlock_K * NUM_INT4_PER_WARP_4BIT;
+  // StartPTR for each WARP, considering SplitK
+  const size_t WARP_Start_UnitID_K = StartBlockID_K;
+  WARP_StartGPTR_A_1BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_1BIT;
+  WARP_StartGPTR_A_2BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_2BIT;
+  WARP_StartGPTR_A_4BIT += WARP_Start_UnitID_K * NUM_INT4_PER_WARP_4BIT;
+  // Copying A tile from Global to Shared, using double-buffer
+  // //////////////////////////////////////////////////////////
+  // StartSPTR for each ThreadBlock
+  uint32_t* AFrag_1BIT_SPTR = reinterpret_cast<uint32_t*>(smem);
+  uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT / 4;
+  uint32_t* AFrag_4BIT_SPTR =
+      AFrag_2BIT_SPTR +
+      SMEM_SIZE_PER_TB_2BIT /
+          4;  // 8 buffers including double buffers, 12 for triple buffers
+  // StartSPTR for each WARP
+  AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT / 4;
+  AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT / 4;
+  AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT / 4;
+  // Pre-fetch of A tile
+  for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) {
+    if (USE_SEG_1BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(
+          AFrag_1BIT_SPTR + i * SMEM_SIZE_PER_WARP_1BIT / 4 * 4,
+          WARP_StartGPTR_A_1BIT);
+    if (USE_SEG_2BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(
+          AFrag_2BIT_SPTR + i * SMEM_SIZE_PER_WARP_2BIT / 4 * 4,
+          WARP_StartGPTR_A_2BIT);
+    if (USE_SEG_4BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(
+          AFrag_4BIT_SPTR + i * SMEM_SIZE_PER_WARP_4BIT / 4 * 4,
+          WARP_StartGPTR_A_4BIT);
+    WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT / 16;
+    WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT / 16;
+    WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT / 16;
+  }
+  // Global Memory Address for Matrix A (QuantScale)
+  // /////////////////////////////////////////////////////////////////////
+  const half* TB_StartGPTR_A_Scale =
+      Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64;
+  const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64;
+  CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64,
+                                WARP_StartGPTR_A_Scales);
+  // Copying B tile from Global to Shared, considering SplitK
+  // /////////////////////////////////////////////////////////////
+  const half* BTile_GPTR =
+      B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K;
+  for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) {
+    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>(
+        smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global,
+        NumColumnToCopy);
+    BTile_GPTR += TilingConfig::TILE_K;
+  }
+  // Register Allocation for A, B, and C, Initialized to Zeros
+  // /////////////////////////////////////////////////////////////////////
+  constexpr int NumRegSets_a =
+      WARP_ROW_MMA_TENSORS;  // 1 set = 4 registers, containing a 16*16 MMA
+                             // block
+  constexpr int NumRegSets_b =
+      (TilingConfig::WARP_COL_MMA_TENSORS == 1)
+          ? 1
+          : TilingConfig::WARP_COL_MMA_TENSORS /
+                2;  // 1 set = 4 registers, containing a 16*16 MMA block
+  uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM]
+            [4];  // double/Triple buffer is used // Registers to store
+                  // decompressed FP6
+  uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM]
+            [4];  // double/Triple buffer is used // Register to store FP16 B
+                  // matrix (a slice)
+  float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16];
+  for (int i = 0; i < NumRegSets_a * NumRegSets_b; i++)
+    for (int j = 0; j < REG_PER_THREAD_C_TENSOR_16_16; j++) c[i][j] = 0.0f;
+  //
+  cp_async_wait_all();
+  __syncthreads();
+
+  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+  uint32_t Scales_RPTR[4];  // 4 Registers per thread for Quantization Scales
+  ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i * 64);
+  // Initializing the Software Pipeline: writing registers.
+  // ////////////////////////////////////////////////////////////////////////////////////////////////
+  initialize_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+      a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array,
+      Scales_RPTR);
+// The outer loop.
+// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+#pragma unroll(1)
+  for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) {
+    // Triple-Buffer for A Tile
+    uint32_t* __restrict__ read_SPTR_Frag_1bit =
+        AFrag_1BIT_SPTR +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT / 4 *
+            4;  // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ read_SPTR_Frag_2bit =
+        AFrag_2BIT_SPTR +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT / 4 *
+            4;  // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ read_SPTR_Frag_4bit =
+        AFrag_4BIT_SPTR +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT / 4 *
+            4;  // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ read2_SPTR_Frag_1bit =
+        AFrag_1BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
+                              SMEM_SIZE_PER_WARP_1BIT / 4 * 4;
+    uint32_t* __restrict__ read2_SPTR_Frag_2bit =
+        AFrag_2BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
+                              SMEM_SIZE_PER_WARP_2BIT / 4 * 4;
+    uint32_t* __restrict__ read2_SPTR_Frag_4bit =
+        AFrag_4BIT_SPTR + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) *
+                              SMEM_SIZE_PER_WARP_4BIT / 4 * 4;
+    uint32_t* __restrict__ write_SPTR_Frag_1bit =
+        AFrag_1BIT_SPTR +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            SMEM_SIZE_PER_WARP_1BIT / 4 *
+            4;  // 512  (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ write_SPTR_Frag_2bit =
+        AFrag_2BIT_SPTR +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            SMEM_SIZE_PER_WARP_2BIT / 4 *
+            4;  // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    uint32_t* __restrict__ write_SPTR_Frag_4bit =
+        AFrag_4BIT_SPTR +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            SMEM_SIZE_PER_WARP_4BIT / 4 *
+            4;  // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
+    // Triple-Buffer for B Tile
+    // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is
+    // changed to below. similarly for read2_SPTR and write_SPTR.
+    half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+        smem_array +
+        ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+        smem_array +
+        ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
+        smem_array +
+        ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) *
+            TilingConfig::TILE_N;
+    //
+    bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter;
+    // Copying A tile from Global to Shared, bypassing L1, using double-buffer
+    if (USE_SEG_1BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_1BIT>(
+          write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy);
+    if (USE_SEG_2BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_2BIT>(
+          write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy);
+    if (USE_SEG_4BIT)
+      CopyFromGlobalToShared_A<SMEM_SIZE_PER_WARP_4BIT>(
+          write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy);
+    // copying B tile from GlobalMemory to SharedMemory
+    CopyFromGlobalToShared<TilingConfig::TILE_N, TilingConfig::BLOCK_WARPS>(
+        write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy);
+    cp_async_group_commit();
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
+        read_SPTR, Scales_RPTR,
+        1);  // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each
+             // WARP; read_SPTR is shared among WARPs
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
+        read_SPTR, Scales_RPTR, 2);
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit,
+        read_SPTR, Scales_RPTR, 3);
+    // Barriers and Synchronizations
+    cp_async_wait_group<PIPELINE_LEVEL_GMEM - 2>();
+    __syncthreads();
+    core_mma_slice<TilingConfig, EXPONENT, MANTISSA>(
+        c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit,
+        read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0);
+    // Updating global PTRs
+    WARP_StartGPTR_A_1BIT +=
+        SMEM_SIZE_PER_WARP_1BIT / 16;  // 2KB/16=128 (1)/16: int4*+1 = char*+16
+    WARP_StartGPTR_A_2BIT +=
+        SMEM_SIZE_PER_WARP_2BIT / 16;  // 4KB/16=256 (1)/16: int4*+1 = char*+16
+    WARP_StartGPTR_A_4BIT +=
+        SMEM_SIZE_PER_WARP_4BIT / 16;  // 8KB/16=512 (1)/16: int4*+1 = char*+16
+    BTile_GPTR += TilingConfig::TILE_K;
+  }
+  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+  /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+  // Store the C fragments to shared memory.
+  float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] =
+      reinterpret_cast<
+          float(*)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4]>(smem);
+  StoreToSharedMemoryFromRegister<TilingConfig>(smem_CFrag, c);
+  __syncthreads();
+  // Now that shared memory contains all the D tiles, stream them to global
+  // memory.
+  OutputDataType* BlockGlobalPTR = C + BatchID * (M_Global * N_Global) +
+                                   Tile_Start_M + Tile_Start_N * M_Global;
+  for (size_t i = warpId; i < NumColumnToCopy;
+       i += TilingConfig::BLOCK_WARPS)  // i-th column
+#pragma unroll
+    for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M;
+         j += WARP_SIZE)  // j-th row
+    {
+      if constexpr (std::is_same<OutputDataType, half>::value)
+        BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]);
+      else
+        BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j];
+    }
+}
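
A note on the 1+2+4 weight split used by QUANT_GEMM_Kernel: a BIT_WIDTH-bit weight is stored as up to three bit-plane segments of 1, 2 and 4 bits, and a segment is active exactly when the corresponding bit is set in BIT_WIDTH (USE_SEG_xBIT = BIT_WIDTH & x). A small Python sketch of that bookkeeping (illustrative only, not kernel code):

def active_segments(exponent: int, mantissa: int):
    # Mirrors USE_SEG_1BIT/2BIT/4BIT = BIT_WIDTH & 1/2/4 in the kernel above.
    bit_width = 1 + exponent + mantissa
    segments = [w for w in (1, 2, 4) if bit_width & w]
    assert sum(segments) == bit_width  # the segments reassemble the full value
    return segments

print(active_segments(3, 2))  # FP6 (e3m2) -> [2, 4]
print(active_segments(2, 2))  # FP5 (e2m2) -> [1, 4]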

+ 70 - 0
kernels/quantization/fp6/kernel_reduction.cuh

@@ -0,0 +1,70 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is copied from
+// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh
+
+/***************************************************************************
+ * Copyright 2023 The FLash-LLM Authors. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ***************************************************************************/
+// Used for the reduction of result matrix if Split-K is used
+// Reduction_Workspace:     (Split_K, M_Global, N_Global),  column major
+// C:                       (M_Global, N_Global),           column major
+// Each thread deals with 8 output elements, each element is the sum of Split_K
+// elements
+//      Read  Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8
+//                    float_per_thread (256 bit) -> 256 float per warp
+//      Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8
+//                    half_per_thread (128 bit) -> 256 half per warp
+// GridSize = (M_Global*N_Global) / 256
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#define REDUCTION_ELEMENT_PER_THREADBLOCK 256
+#define HALF_PER_128BIT 8
+
+__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace,
+                                 size_t M_Global, size_t N_Global,
+                                 int Split_K) {
+  half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
+  float* WARP_GPTR_R =
+      Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
+  half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT;
+  float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT;
+  // Initializing Thread-Local Results
+  float Results[HALF_PER_128BIT];
+#pragma unroll
+  for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f;
+  // Reduction
+  for (int i = 0; i < Split_K; i++) {
+#pragma unroll
+    for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j];
+    THREAD_GPTR_R += M_Global * N_Global;
+  }
+// Writing to global memory
+#pragma unroll
+  for (int i = 0; i < HALF_PER_128BIT; i++)
+    THREAD_GPTR_C[i] = __float2half_rn(Results[i]);
+}
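
The Split-K reduction above sums the FP32 partial results produced by each split and writes the total back as FP16. A NumPy sketch of the same computation at the tensor level (a reference for testing, not the kernel itself; the function name is hypothetical):

import numpy as np

def splitk_reduce_reference(reduction_workspace: np.ndarray) -> np.ndarray:
    # reduction_workspace: (Split_K, M_Global, N_Global), float32. The kernel
    # works on column-major storage, but the arithmetic is the same here.
    return reduction_workspace.sum(axis=0).astype(np.float16)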

+ 82 - 0
kernels/quantization/fp6/ptx_cp.async.cuh

@@ -0,0 +1,82 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is copied from
+// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh
+
+/***************************************************************************
+ * Copyright 2023 The FLash-LLM Authors. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ***************************************************************************/
+// Extended from CUTLASS's source code
+
+#ifndef PTX_CP_ASYNC_CUH
+#define PTX_CP_ASYNC_CUH
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+template <int SizeInBytes>
+__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr,
+                                         bool pred_guard = true) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  static_assert(SizeInBytes == 16, "Size is not supported");
+  unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);
+  asm volatile(
+      "{ \n"
+      "  .reg .pred p;\n"
+      "  setp.ne.b32 p, %0, 0;\n"
+      "  @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred_guard),
+      "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
+#endif
+}
+
+/// Establishes an ordering w.r.t. previously issued cp.async instructions. Does
+/// not block.
+__device__ __forceinline__ void cp_async_group_commit() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cp.async.commit_group;\n" ::);
+#endif
+}
+
+/// Blocks until all but <N> previous cp.async.commit_group operations have
+/// committed.
+template <int N>
+__device__ __forceinline__ void cp_async_wait_group() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+#endif
+}
+
+/// Blocks until all previous cp.async.commit_group operations have committed.
+// cp.async.wait_all is equivalent to:
+// cp.async.commit_group;
+// cp.async.wait_group 0;
+__device__ __forceinline__ void cp_async_wait_all() {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("cp.async.wait_all;\n" ::);
+#endif
+}
+
+#endif

+ 108 - 0
kernels/quantization/fp6/ptx_mma.cuh

@@ -0,0 +1,108 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh
+
+/***************************************************************************
+ * Copyright 2023 The FLash-LLM Authors. All rights reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ***************************************************************************/
+#ifndef PTX_MMA_CUH
+#define PTX_MMA_CUH
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <assert.h>
+#include "configs.h"
+
+// MODIFICATION NOTE: to support MSVC
+// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__
+// Reg)[4]
+// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__
+// read_SPTR)
+template <typename TilingConfig>
+__device__ __forceinline__ void B_FromSharedToReg(
+    uint32_t (*__restrict__ Reg)[4],
+    half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    int slice_id) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  #ifdef DEBUG_MODE
+  static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
+                (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0));
+  #endif
+
+  const int warpId = threadIdx.x / WARP_SIZE;
+  int lane_id = threadIdx.x % WARP_SIZE;
+  int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS;
+  int warp_start_col =
+      TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 *
+      WARP_j;  // each warp may start from reading warp_start_col'th column of
+               // the B tile in shared memory
+  #ifdef DEBUG_MODE
+  assert(warp_start_col == 0);
+  #endif
+
+  int col = (lane_id % 8) + (lane_id / 16) * 8;
+  int row = (lane_id % 16) / 8 * 8;
+  uint32_t smem_local_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(
+      &read_SPTR[warp_start_col + col][slice_id * MMA_16 + row]));
+  if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
+    asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
+                 : "=r"(Reg[0][0]), "=r"(Reg[0][1])
+                 : "r"(smem_local_ptr));
+  } else {
+  #pragma unroll
+    for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) {
+      asm volatile(
+          "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+          : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3])
+          : "r"(smem_local_ptr));
+      smem_local_ptr +=
+          16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
+    }
+  }
+#endif
+}
+
+// MODIFICATION NOTE: to support MSVC, the function signature is changed from
+// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a,
+// uint32_t __restrict__ *b).
+__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c,
+                                                  uint32_t* __restrict__ a,
+                                                  uint32_t* __restrict__ b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+      "{ %0, %1, %2, %3},"
+      "{ %4, %5, %6, %7 },"
+      "{ %8, %9 },"
+      "{ %10, %11, %12, %13 };"
+      : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+        "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
+#endif
+}
+
+#endif  // PTX_MMA_CUH
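
MMA_FP16_M16N8K16 wraps a single mma.sync.aligned.m16n8k16 tensor core instruction. Ignoring the per-thread register layout that B_FromSharedToReg and the A fragments take care of, its tile-level arithmetic is D = A @ B + C with A 16x16 FP16, B 16x8 FP16 and C/D 16x8 FP32. A NumPy sketch of just that arithmetic (illustrative only; the function name is hypothetical):

import numpy as np

def mma_m16n8k16_reference(a_f16: np.ndarray, b_f16: np.ndarray,
                           c_f32: np.ndarray) -> np.ndarray:
    # a_f16: (16, 16) float16, b_f16: (16, 8) float16, c_f32: (16, 8) float32.
    assert a_f16.shape == (16, 16) and b_f16.shape == (16, 8)
    return a_f16.astype(np.float32) @ b_f16.astype(np.float32) + c_f32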

+ 188 - 0
kernels/quantization/fp6/utils_core.cuh

@@ -0,0 +1,188 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh
+
+#ifndef UTILS_CORE_CUH
+#define UTILS_CORE_CUH
+
+#include <assert.h>
+
+#include "configs.h"
+#include "ptx_mma.cuh"
+#include "utils_parallel_dequant.cuh"
+
+template <int NUM_INT_PER_THREAD>
+__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[],
+                                                               uint32_t* SPTR,
+                                                               int slice_id) {
+  SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE);
+  int lane_id = threadIdx.x % WARP_SIZE;
+#pragma unroll
+  for (int i = 0; i < NUM_INT_PER_THREAD; i++) {
+    Reg[i] = SPTR[lane_id + i * WARP_SIZE];
+  }
+}
+
+// MODIFICATION NOTE: to support MSVC, half __restrict__
+// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
+template <typename TilingConfig, int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void initialize_mma_slice(
+    uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A_1BIT_SPTR_read,
+    uint32_t* __restrict__ A_2BIT_SPTR_read,
+    uint32_t* __restrict__ A_4BIT_SPTR_read,
+    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    uint32_t* RPTR_Scales) {
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+  // Writing registers
+  // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6
+  // per thread => 6 registers per thread;
+  uint32_t a_1bit[1];  // NO double buffer
+  uint32_t a_2bit[2];  // NO double buffer
+  uint32_t a_4bit[4];  // NO double buffer
+  if (USE_SEG_1BIT)
+    CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1BIT_SPTR_read, 0);
+  if (USE_SEG_2BIT)
+    CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2BIT_SPTR_read, 0);
+  if (USE_SEG_4BIT)
+    CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4BIT_SPTR_read, 0);
+  Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
+      a, a_1bit, a_2bit, a_4bit,
+      RPTR_Scales);  // SIMT Dequant: dequantizing FPx to FP16 at register
+                     // level, dequantizing a slice each time
+  B_FromSharedToReg<TilingConfig>(b, B_SPTR_read,
+                                  0);  // Loading B from shared to registers
+}
+
+// MODIFICATION NOTE: to support MSVC, half __restrict__
+// (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
+template <typename TilingConfig, int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void core_mma_slice(
+    float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4],
+    uint32_t (*b)[4], uint32_t* __restrict__ A_1bit_SPTR_read,
+    uint32_t* __restrict__ A_2bit_SPTR_read,
+    uint32_t* __restrict__ A_4bit_SPTR_read,
+    half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    uint32_t* RPTR_Scales,
+    int slice_id)  // writing slice[slice_id] to registers, k=0 -> slice_id=1
+                   // for prefetching
+{
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+
+#ifdef DEBUG_MODE
+  assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
+         (TilingConfig::WARP_COL_MMA_TENSORS % 2 ==
+          0));  // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded
+                // to a 16*16 MMA block
+#endif
+  const int NumRegSets_a =
+      WARP_ROW_MMA_TENSORS;  // 1 set = 4 registers, containing a 16*16 MMA
+                             // block
+  const int NumRegSets_b =
+      (TilingConfig::WARP_COL_MMA_TENSORS == 1)
+          ? 1
+          : TilingConfig::WARP_COL_MMA_TENSORS /
+                2;  // 1 set = 4 registers, containing a 16*16 MMA block
+  uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] =
+      reinterpret_cast<uint32_t(*)[REG_PER_THREAD_C_TENSOR_16_16]>(
+          c);  // GlobalRegisters for accumulated FP32 results
+
+  // Setting RPTRs for double buffers
+  uint32_t(*a_read)[4] = a;
+  uint32_t(*a_write)[4] = a;
+  uint32_t(*b_read)[4] = b;
+  uint32_t(*b_write)[4] = b;
+  if (slice_id % 2 == 1) {
+    b_write += NumRegSets_b;
+    a_write += NumRegSets_a;
+  } else {
+    b_read += NumRegSets_b;
+    a_read += NumRegSets_a;
+  }
+
+// Reading registers and issuing core tensor core computations (a slice of A and
+// B tile in shared memory)
+#pragma unroll
+  for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
+    if (TilingConfig::WARP_COL_MMA_TENSORS == 1) {
+      MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]);
+    } else {
+#pragma unroll
+      for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) {
+        MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i],
+                          b_read[j]);
+        MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4,
+                          a_read[i], b_read[j] + 2);  // c+4; b+2
+      }
+    }
+  }
+  // Writing registers
+  // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6
+  // per thread => 6 registers per thread;
+  uint32_t a_1bit[1];  // NO double buffer
+  uint32_t a_2bit[2];  // NO double buffer
+  uint32_t a_4bit[4];  // NO double buffer
+  if (USE_SEG_1BIT)
+    CopyFromSharedToRegister_AFrag<1>(a_1bit, A_1bit_SPTR_read, slice_id);
+  if (USE_SEG_2BIT)
+    CopyFromSharedToRegister_AFrag<2>(a_2bit, A_2bit_SPTR_read, slice_id);
+  if (USE_SEG_4BIT)
+    CopyFromSharedToRegister_AFrag<4>(a_4bit, A_4bit_SPTR_read, slice_id);
+  Dequant_32FP6_4Way<EXPONENT, MANTISSA>(
+      a_write, a_1bit, a_2bit, a_4bit,
+      RPTR_Scales);  // SIMT Dequant: dequantizing FP6 to FP16 at register
+                     // level, dequantizing a slice each time
+  B_FromSharedToReg<TilingConfig>(
+      b_write, B_SPTR_read, slice_id);  // Loading B from shared to registers
+}
+
+template <typename TilingConfig>
+__device__ __forceinline__ void StoreToSharedMemoryFromRegister(
+    float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4],
+    float c[][REG_PER_THREAD_C_TENSOR_16_16]) {
+  const int lane_id = threadIdx.x % WARP_SIZE;
+  const int warpId = threadIdx.x / WARP_SIZE;
+  int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS);
+#pragma unroll
+  for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) {
+#pragma unroll
+    for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS;
+         j++) {  // Dealing with one 16*8 Tensor
+      int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS;
+      int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2);
+      int Tensor_row_offset = warp_row_offset + i * MMA_16;
+      int Tensor_col_offset = j * MMA_8;
+#pragma unroll
+      for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) {
+        int row_offset = lane_id / 4;
+        if (r >= 2) row_offset += 8;
+        int col_offset = (lane_id % 4) * 2;
+        if (r % 2 == 1) col_offset += 1;
+        smem_CFrag[Tensor_col_offset + col_offset]
+                  [Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset];
+      }
+    }
+  }
+}
+
+#endif
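
core_mma_slice double-buffers the A/B register fragments: the slice currently fed to the tensor cores is read from one half of the register sets while the next slice (slice_id) is dequantized into the other half, with the halves chosen by the parity of slice_id. A tiny Python sketch of that selection (illustrative only; buffer_roles is a hypothetical name):

def buffer_roles(slice_id: int):
    # Matches the pointer adjustment in core_mma_slice: odd slice_id writes
    # the second half and reads the first, even slice_id does the opposite.
    if slice_id % 2 == 1:
        return {"read": 0, "write": 1}
    return {"read": 1, "write": 0}

# Per K tile the kernel issues slices 1, 2, 3 and then slice 0 of the next
# tile, so the read/write halves ping-pong on every call.
print([buffer_roles(s) for s in (1, 2, 3, 0)])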

+ 94 - 0
kernels/quantization/fp6/utils_gmem.cuh

@@ -0,0 +1,94 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
+
+#ifndef UTILS_GMEM_CUH
+#define UTILS_GMEM_CUH
+
+#include <assert.h>
+#include "configs.h"
+#include "ptx_cp.async.cuh"
+
+/*
+ * Copying A1/A2 from global memory to shared memory.
+ * Usually 1024 or 2048 Bytes
+ */
+template <int SMEM_SIZE_IN_BYTES_PER_WARP>
+__device__ __forceinline__ void CopyFromGlobalToShared_A(
+    uint32_t* SPTR, const uint4* GPTR, bool pred_guard = true) {
+#ifdef DEBUG_MODE
+  static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0);
+#endif
+  int lane_id = threadIdx.x % WARP_SIZE;
+  half* SPTR_HALF = reinterpret_cast<half*>(SPTR);
+  const half* GPTR_HALF = reinterpret_cast<const half*>(GPTR);
+  SPTR_HALF += lane_id * 8;
+  GPTR_HALF += lane_id * 8;
+#pragma unroll
+  for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) {
+    cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard);
+    SPTR_HALF += 256;  // Forward 512 Bytes
+    GPTR_HALF += 256;  // Forward 512 Bytes
+  }
+}
+
+/*
+ * Copying 64 Quant Scales (FP16) from global memory to shared memory.
+ */
+__device__ __forceinline__ void CopyFromGlobalToShared_Scales(
+    half* SPTR_QuantScales, const half* GPTR_A_Scales) {
+  int lane_id = threadIdx.x % WARP_SIZE;
+  int Offset_Shared = lane_id * 2;
+  int Offset_Global = lane_id / 4 + (lane_id % 4) * 16;
+  for (int i = 0; i < 2; i++)
+    SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8];
+}
+
+// MODIFICATION NOTE: to support MSVC, half __restrict__
+// (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
+/*
+ * (1) Copying X  rows * 64 columns of FP16 values, originally in row    major
+ * (2) Copying 64 rows * X  columns of FP16 values, originally in column major
+ * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8
+ * Threads
+ */
+template <int MaxNumOfLinesToCopy, int BLOCK_WARPS>
+__device__ __forceinline__ void CopyFromGlobalToShared(
+    half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
+    const half* GlobalPTR, const int GlobalStride,
+    const int NumOfLinesLeft,  // To support arbitrary N dimensions.
+    bool Pred = true) {
+  // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
+  const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
+  const int NumOfGroups = NumOfThreads / 8;
+  const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1;
+  // runtime variables
+  const int line_id = threadIdx.x / 8;
+  const int line_offset = (threadIdx.x % 8) * 8;
+  // PTR for source global memory and target shared memory
+  GlobalPTR += line_id * GlobalStride + line_offset;
+  SharedPTR += line_id;
+#pragma unroll
+  for (int i = 0; i < MaxIteration; i++) {
+    bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred;
+    cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred);
+    //
+    GlobalPTR += NumOfGroups * GlobalStride;
+    SharedPTR += NumOfGroups;
+  }
+}
+
+#endif
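
The index arithmetic in CopyFromGlobalToShared_Scales distributes 64 FP16 scales across a warp: each lane writes two consecutive shared slots and reads two global elements 8 apart, and the 32 lanes together touch every global index exactly once. A short Python check of that mapping (illustrative only; the function name is hypothetical):

def scale_copy_mapping():
    # shared index = lane*2 + i, global index = lane//4 + (lane%4)*16 + i*8
    mapping = {}
    for lane in range(32):
        for i in range(2):
            mapping[lane * 2 + i] = lane // 4 + (lane % 4) * 16 + i * 8
    assert sorted(mapping.values()) == list(range(64))  # a permutation of 0..63
    return mapping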

+ 148 - 0
kernels/quantization/fp6/utils_parallel_dequant.cuh

@@ -0,0 +1,148 @@
+//    Copyright 2024 FP6-LLM authors
+//
+//    Licensed under the Apache License, Version 2.0 (the "License");
+//    you may not use this file except in compliance with the License.
+//    You may obtain a copy of the License at
+//
+//        http://www.apache.org/licenses/LICENSE-2.0
+//
+//    Unless required by applicable law or agreed to in writing, software
+//    distributed under the License is distributed on an "AS IS" BASIS,
+//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+//    See the License for the specific language governing permissions and
+//    limitations under the License.
+//
+// This file is modified from
+// https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh
+// To support MSVC, all instances of u_int32_t are changed to uint32_t.
+
+#ifndef UTILS_PARALLELDEQUANT_CUH
+#define UTILS_PARALLELDEQUANT_CUH
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+/*
+ * Input:   R1
+ * Outputs: R1, R2
+ * Note:    Simplified Exponent calculation is applied.
+ */
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t* In, uint32_t* Out1,
+                                                   uint32_t* Out2) {
+  //
+  constexpr int RIGHT_SHIFT = 5 - EXPONENT;
+  constexpr int MASK1 = 0x80000000;
+  constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
+  constexpr int MASK3 = MASK2 & 0x7fffffff;
+  constexpr int MASK = MASK3 | MASK3 >> 16;
+  //
+  *Out1 = *In & 0x80008000;
+  *Out1 |= ((*In) & MASK) >> RIGHT_SHIFT;
+  //
+  *In = (*In) << 8;
+  *Out2 = *In & 0x80008000;
+  *Out2 |= ((*In) & MASK) >> RIGHT_SHIFT;
+}
+
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair,
+                                              half Scale) {
+  constexpr int BIAS_OFFSET = (int(1) << (5 - 1)) - (int(1) << (EXPONENT - 1));
+  constexpr int BIAS = int(1) << BIAS_OFFSET;
+  //
+  half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
+  half* FP16_2 = FP16_1 + 1;
+  uint32_t output;
+  half* output_half_ptr = reinterpret_cast<half*>(&output);
+  output_half_ptr[0] =
+      __hmul(__hmul(*FP16_1, __float2half(1.0f * BIAS)), Scale);
+  output_half_ptr[1] =
+      __hmul(__hmul(*FP16_2, __float2half(1.0f * BIAS)), Scale);
+  return output;
+}
+
+// MODIFICATION NOTE: to support MSVC
+// - u_int32_t __restrict__ Reg[][4] is changed to below.
+// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. similarly for
+// read_RPTR_2bit and read_RPTR_4bit
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void Dequant_32FP6_4Way(
+    uint32_t (*__restrict__ Reg)[4], uint32_t* __restrict__ read_RPTR_1bit,
+    uint32_t* __restrict__ read_RPTR_2bit,
+    uint32_t* __restrict__ read_RPTR_4bit, uint32_t* Scales) {
+  // 1+2+4 weight split
+  constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+  constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+  constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+  constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+  //
+  uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg);
+  uint32_t* Frag_PTR_1bit = read_RPTR_1bit;
+  uint32_t* Frag_PTR_2bit = read_RPTR_2bit;
+  uint32_t* Frag_PTR_4bit = read_RPTR_4bit;
+  half* Scale_RPTR = reinterpret_cast<half*>(Scales);
+// Dequantizing 32 FP6, each loop iteration dequantizing 4 FP6
+#pragma unroll(8)
+  for (int i = 0; i < 8; i++) {
+    uint32_t Packed_FP6 = 0;
+    uint32_t tmp = 0;
+    // 1bit Frag
+    if (USE_SEG_1BIT) {
+      tmp = (*Frag_PTR_1bit) & 0x80808080;
+      Packed_FP6 |= tmp >> (BIT_WIDTH & 0);
+      if (i % 8 == 7)
+        Frag_PTR_1bit++;
+      else
+        (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1;
+    }
+    // 2bit Frag
+    if (USE_SEG_2BIT) {
+      tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0;
+      Packed_FP6 |= tmp >> (BIT_WIDTH & 1);
+      if (i % 4 == 3)
+        Frag_PTR_2bit++;
+      else
+        (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2;
+    }
+    // 4bit Frag
+    if (USE_SEG_4BIT) {
+      tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0;
+      Packed_FP6 |= tmp >> (BIT_WIDTH & 3);
+      if (i % 2 == 1)
+        Frag_PTR_4bit++;
+      else
+        (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4;
+    }
+    //
+    uint32_t out1, out2;
+    FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
+    //
+    *OutputRegs = MultScale<EXPONENT, MANTISSA>(
+        out1, Scale_RPTR[0]);  // Multiply FP16 scales
+    OutputRegs += 1;
+    *OutputRegs = MultScale<EXPONENT, MANTISSA>(
+        out2, Scale_RPTR[1]);  // Multiply FP16 scales
+    OutputRegs += 1;
+    // Updating offset for FP16 scales for every two iterations
+    if (i % 2 == 1) Scale_RPTR += 2;
+  }
+}
+
+/*
+ * Extracting quantization scales from shared memory to registers: each lane
+ * loads one uint32 (two FP16 scales), then gathers the four uint32 values of
+ * its 4-lane group via warp shuffle, ending up with 8 scales per thread.
+ */
+__device__ __forceinline__ void ExtractFromSharedToReg_Scales(
+    uint32_t* Scales, half* WARP_SPTR_Scales) {
+  int lane_id = threadIdx.x % WARP_SIZE;
+  uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales);
+  uint32_t tmpReg = SPTR_uint[lane_id];
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
+    Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4);
+  }
+}
+
+#endif
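
In MultScale, BIAS = 2^(2^(5-1) - 2^(EXPONENT-1)) is the exponent-bias correction applied after FPx_FP16_Cast_4Way drops the FPx exponent bits into the FP16 exponent field: it equals 2^(fp16_bias - fpx_bias) with fp16_bias = 15 and fpx_bias = 2^(EXPONENT-1) - 1. A short Python check (illustrative only; bias_factor is a hypothetical name):

def bias_factor(exponent: int) -> int:
    # Same expression as the constexpr BIAS in MultScale.
    bias_offset = (1 << (5 - 1)) - (1 << (exponent - 1))
    return 1 << bias_offset

for e in range(1, 6):  # exponent widths used by the FP2..FP7 formats above
    fp16_bias, fpx_bias = 15, (1 << (e - 1)) - 1
    assert bias_factor(e) == 2 ** (fp16_bias - fpx_bias)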

+ 6 - 0
kernels/quantization/quant_ops.h

@@ -132,6 +132,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                               torch::Tensor& workspace, int64_t size_m,
                               int64_t size_n, int64_t size_k);
 
+torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
+                                          torch::Tensor _in_feats,
+                                          torch::Tensor _weights,
+                                          torch::Tensor _scales,
+                                          int64_t splitK = 1);
+
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

+ 9 - 0
kernels/torch_bindings.cpp

@@ -196,6 +196,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // QuIP# Decompress
   ops.def("quip_decompress", &decompress_e8p_origorder);
   ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
+
+  // fp6_llm
+  ops.def(
+      "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA,"
+      "                            Tensor _in_feats, Tensor _weights,"
+      "                            Tensor _scales, int splitK=1) -> Tensor");
+  ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
+           &fp_eXmY_linear_forward_cuda);
+
 #endif
 
   // Quantized GEMM for GPTQ.

+ 9 - 0
tests/benchmarks/engine/throughput.py

@@ -66,6 +66,7 @@ def run_aphrodite(
     model: str,
     tokenizer: str,
     quantization: Optional[str],
+    quant_llm_fp_bits: Optional[int],
     tensor_parallel_size: int,
     seed: int,
     n: int,
@@ -90,6 +91,7 @@ def run_aphrodite(
         model=model,
         tokenizer=tokenizer,
         quantization=quantization,
+        quant_llm_fp_bits=quant_llm_fp_bits,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
         trust_remote_code=trust_remote_code,
@@ -226,6 +228,7 @@ def main(args: argparse.Namespace):
     if args.backend == "aphrodite":
         elapsed_time = run_aphrodite(
             requests, args.model, args.tokenizer, args.quantization,
+            args.quant_llm_fp_bits,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype,
@@ -286,6 +289,12 @@ if __name__ == "__main__":
                         '-q',
                         choices=[*QUANTIZATION_METHODS, None],
                         default=None)
+    parser.add_argument('--quant-llm-fp-bits',
+                        type=int,
+                        default=None,
+                        choices=[4, 5, 6, 7],
+                        help="Number of bits for the FP quantization in "
+                        "QuantLLM")
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
     parser.add_argument("--n",
                         type=int,
                         type=int,