Browse Source

kernels: split marlin kernels for faster compile, fix MoE, temporarily remove HQQ (#1119)

* kernels: split marlin kernels for faster compile, fix MoE, temporarily remove HQQ

* more fixes
AlpinDale 1 month ago
parent
commit
fa84f8102e
29 changed files with 2305 additions and 1534 deletions
  1. 6 1
      CMakeLists.txt
  2. 3 5
      aphrodite/_custom_ops.py
  3. 10 4
      aphrodite/modeling/layers/fused_moe/__init__.py
  4. 235 0
      aphrodite/modeling/layers/fused_moe/fused_marlin_moe.py
  5. 24 118
      aphrodite/modeling/layers/fused_moe/fused_moe.py
  6. 52 25
      aphrodite/modeling/layers/fused_moe/layer.py
  7. 2 0
      aphrodite/modeling/models/mixtral_quant.py
  8. 1 0
      aphrodite/quantization/__init__.py
  9. 28 16
      aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py
  10. 20 4
      aphrodite/quantization/utils/marlin_utils.py
  11. 0 1
      aphrodite/quantization/utils/marlin_utils_fp8.py
  12. 7 4
      aphrodite/quantization/utils/marlin_utils_test.py
  13. 15 7
      aphrodite/quantization/utils/quant_utils.py
  14. 1425 0
      kernels/moe/marlin_kernels/marlin_moe_kernel.h
  15. 29 0
      kernels/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
  16. 20 0
      kernels/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
  17. 29 0
      kernels/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
  18. 18 0
      kernels/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
  19. 27 1092
      kernels/moe/marlin_moe_ops.cu
  20. 5 2
      kernels/moe/marlin_moe_ops.h
  21. 2 3
      kernels/moe/topk_softmax_kernels.cu
  22. 5 4
      kernels/moe/torch_bindings.cpp
  23. 91 225
      kernels/quantization/gptq_marlin/gptq_marlin.cu
  24. 1 1
      kernels/quantization/gptq_marlin/gptq_marlin_repack.cu
  25. 1 1
      kernels/quantization/gptq_marlin/marlin.cuh
  26. 1 3
      kernels/quantization/gptq_marlin/marlin_dtypes.cuh
  27. 1 1
      kernels/quantization/quant_ops.h
  28. 24 13
      kernels/torch_bindings.cpp
  29. 223 4
      tests/kernels/test_moe.py

+ 6 - 1
CMakeLists.txt

@@ -392,10 +392,15 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 
 set(APHRODITE_MOE_EXT_SRC
   "kernels/moe/torch_bindings.cpp"
-  "kernels/moe/softmax.cu")
+  "kernels/moe/topk_softmax_kernels.cu")
 
 if(APHRODITE_GPU_LANG STREQUAL "CUDA")
   list(APPEND APHRODITE_MOE_EXT_SRC
+      "kernels/moe/marlin_kernels/marlin_moe_kernel.h"
+      "kernels/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
+      "kernels/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
+      "kernels/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
+      "kernels/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
       "kernels/moe/marlin_moe_ops.cu")
 endif()
 

+ 3 - 5
aphrodite/_custom_ops.py

@@ -594,7 +594,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                            num_bits: int) -> torch.Tensor:
     num_experts = b_q_weight.shape[0]
     assert size_k % 16 == 0
-    output = torch.empty((num_experts, size_k // 16, size_n * 2),
+    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
                          device=b_q_weight.device,
                          dtype=b_q_weight.dtype)
     for e in range(num_experts):
@@ -616,13 +616,11 @@ def gptq_marlin_gemm(a: torch.Tensor,
                      size_k: int,
                      is_k_full: bool,
                      has_zp: bool = False,
-                     use_fp32_reduce: bool = False,
-                     is_zp_float: bool = False) -> torch.Tensor:
+                     use_fp32_reduce: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                          g_idx, perm, workspace, b_q_type,
                                          size_m, size_n, size_k, is_k_full,
-                                         has_zp, use_fp32_reduce,
-                                         is_zp_float)
+                                         has_zp, use_fp32_reduce)
 
 
 # machete

+ 10 - 4
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -2,16 +2,22 @@ from aphrodite.modeling.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from aphrodite.triton_utils import HAS_TRITON
 
-__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
+__all__ = [
+    "FusedMoE",
+    "FusedMoEMethodBase",
+    "FusedMoeWeightScaleSupported",
+]
 
 if HAS_TRITON:
-
+    from aphrodite.modeling.layers.fused_moe.fused_marlin_moe import (
+        fused_marlin_moe, single_marlin_moe)
     from aphrodite.modeling.layers.fused_moe.fused_moe import (
-        fused_experts, fused_marlin_moe, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
 
     __all__ += [
         "fused_marlin_moe",
+        "single_marlin_moe",
         "fused_moe",
         "fused_topk",
         "fused_experts",

+ 235 - 0
aphrodite/modeling/layers/fused_moe/fused_marlin_moe.py

@@ -0,0 +1,235 @@
+"""Fused MoE utilities for GPTQ."""
+import functools
+from typing import Any, Dict, Optional
+
+import torch
+
+from aphrodite import _custom_ops as ops
+from aphrodite.modeling.layers.fused_moe.fused_moe import (
+    fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+from aphrodite.scalar_type import scalar_types
+
+
+def single_marlin_moe(
+    hidden_states: torch.Tensor,
+    w: torch.Tensor,
+    scales: torch.Tensor,
+    gating_output: torch.Tensor,
+    g_idx: torch.Tensor,
+    perm: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    override_config: Optional[Dict[str, Any]] = None,
+    num_bits: int = 8,
+) -> torch.Tensor:
+    """
+    This function computes the multiplication of hidden_states with expert
+    weights used in Marlin MoE, using weights w and top-k gating mechanism.
+    Its purpose is testing and debugging the fused MoE kernel.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
+    - w (torch.Tensor): The set of expert weights.
+    - scales (torch.Tensor): The quantization scales.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - g_idx (torch.Tensor): The act_order indices.
+    - perm (torch.Tensor): The act_order input permutation.
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - num_bits (bool): The number of bits in expert weights quantization.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
+    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w.is_contiguous(), "Expert weights must be contiguous"
+    assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]
+
+    M, K = hidden_states.shape
+    E = w.shape[0]
+    N = w.shape[2] // (num_bits // 2)
+
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                        renormalize)
+
+    # This might not be an optimal config for a single MMM
+    get_config_func = functools.partial(try_get_optimal_moe_config,
+                                        w.shape,
+                                        w.shape,
+                                        topk_ids.shape[1],
+                                        None,
+                                        override_config=override_config,
+                                        is_marlin=True)
+    config = get_config_func(M)
+
+    block_size_m = config['BLOCK_SIZE_M']
+
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
+
+    max_workspace_size = (N // 64) * 16
+    workspace = torch.zeros(max_workspace_size,
+                            dtype=torch.int,
+                            device="cuda",
+                            requires_grad=False)
+
+    scalar_type = (scalar_types.uint4b8
+                   if num_bits == 4 else scalar_types.uint8b128)
+
+    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
+        hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
+        g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
+        block_size_m, True, False)
+
+    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
+
+
+def fused_marlin_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    g_idx1: torch.Tensor,
+    g_idx2: torch.Tensor,
+    perm1: torch.Tensor,
+    perm2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    override_config: Optional[Dict[str, Any]] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    num_bits: int = 8,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - g_idx1 (torch.Tensor): The first set of act_order indices.
+    - g_idx2 (torch.Tensor): The second set of act_order indices.
+    - perm1 (torch.Tensor): The first act_order input permutation.
+    - perm2 (torch.Tensor): The second act_order input permutation.
+    - topk_weights (torch.Tensor): Top-k weights.
+    - topk_ids (torch.Tensor): Indices of topk-k elements.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+    - num_bits (bool): The number of bits in expert weights quantization.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[
+        0], "Number of tokens mismatch"
+    assert hidden_states.shape[
+        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
+    assert hidden_states.shape[1] == w2.shape[2] // (
+        num_bits // 2), "Hidden size mismatch w2"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype == torch.float16
+    assert num_bits in [4, 8]
+
+    M, K = hidden_states.shape
+    E = w1.shape[0]
+    N = w2.shape[1] * 16
+    topk = topk_ids.shape[1]
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        None,
+        override_config=override_config,
+        is_marlin=True,
+    )
+    config = get_config_func(M)
+
+    block_size_m = config["BLOCK_SIZE_M"]
+
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
+
+    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
+    workspace = torch.zeros(max_workspace_size,
+                            dtype=torch.int,
+                            device="cuda",
+                            requires_grad=False)
+
+    scalar_type = (scalar_types.uint4b8
+                   if num_bits == 4 else scalar_types.uint8b128)
+
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
+        hidden_states,
+        w1,
+        sorted_token_ids,
+        topk_weights,
+        topk_ids,
+        w1_scale,
+        g_idx1,
+        perm1,
+        workspace,
+        scalar_type,
+        M,
+        2 * N,
+        K,
+        True,
+        E,
+        topk,
+        block_size_m,
+        True,
+        False,
+    )
+
+    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+
+    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
+        intermediate_cache2,
+        w2,
+        sorted_token_ids,
+        topk_weights,
+        topk_ids,
+        w2_scale,
+        g_idx2,
+        perm2,
+        workspace,
+        scalar_type,
+        M,
+        K,
+        N,
+        True,
+        E,
+        topk,
+        block_size_m,
+        False,
+        True,
+    )
+
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                     dim=1)

+ 24 - 118
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -321,15 +321,22 @@ def get_moe_configs(E: int, N: int,
     return None
 
 
-def get_default_config(M: int, E: int, N: int, K: int, topk: int,
-                       dtype: Optional[str],
-                       is_marlin: bool) -> Dict[str, int]:
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+    is_marlin: bool,
+) -> Dict[str, int]:
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
         'BLOCK_SIZE_K': 32,
         'GROUP_SIZE_M': 8
     }
+    # A heuristic: fused marlin works faster with this config for small M
     if M <= E or (is_marlin and M <= 32):
         config = {
             'BLOCK_SIZE_M': 16,
@@ -340,14 +347,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
     return config
 
 
-def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
-                               w2_shape: Tuple[int, ...],
-                               top_k: int,
-                               dtype: Optional[str],
-                               M: int,
-                               override_config: Optional[Dict[str,
-                                                              Any]] = None,
-                               is_marlin: bool = False):
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+    is_marlin: bool = False,
+):
     if override_config:
         config = override_config
     else:
@@ -389,6 +397,7 @@ def fused_topk(
                                         topk,
                                         dtype=torch.int32,
                                         device=hidden_states.device)
+
     ops.topk_softmax(
         topk_weights,
         topk_ids,
@@ -399,6 +408,7 @@ def fused_topk(
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
     return topk_weights, topk_ids
 
 
@@ -432,114 +442,8 @@ def grouped_topk(hidden_states: torch.Tensor,
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids.to(torch.int32)
-
-
-def fused_marlin_moe(hidden_states: torch.Tensor,
-                     w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     gating_output: torch.Tensor,
-                     g_idx1: torch.Tensor,
-                     g_idx2: torch.Tensor,
-                     rand_perm1: torch.Tensor,
-                     rand_perm2: torch.Tensor,
-                     topk: int,
-                     custom_routing_function: Optional[Callable] = None,
-                     renormalize: bool = True,
-                     override_config: Optional[Dict[str, Any]] = None,
-                     use_fp8: bool = False,
-                     w1_scale: Optional[torch.Tensor] = None,
-                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[
-        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-
-    #TODO fp8 is not implemented yet
-    assert not use_fp8
-
-    M, K = hidden_states.shape
-    E = w1.shape[0]
-    N = w2.shape[1] * 16
-
-    if custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                            renormalize)
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize)
-
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w1.shape,
-                                        w2.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-
-    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
-        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
-        block_size_m, True, False)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
-
-    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
-        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
-        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
-        block_size_m, False, True)
 
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
 def get_config_dtype_str(dtype: torch.dtype,
@@ -581,6 +485,8 @@ def fused_experts(hidden_states: torch.Tensor,
 
     num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.APHRODITE_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
     config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,

+ 52 - 25
aphrodite/modeling/layers/fused_moe/layer.py

@@ -131,7 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
 
-        from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe
+        from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
@@ -302,10 +302,28 @@ class FusedMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         param_data[expert_id] = loaded_weight
 
+    def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
+                    shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
+
+        if shard_id == "w2":
+            self._load_w2(shard_id=shard_id,
+                          shard_dim=shard_dim,
+                          loaded_weight=loaded_weight,
+                          expert_data=expert_data,
+                          tp_rank=tp_rank)
+        else:
+            assert shard_id in ("w1", "w3")
+            expert_data.copy_(loaded_weight)
+
     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
                       shard_id: str, expert_id: int) -> None:
 
+        # compressed-tensors represents weights on disk which are flipped
+        loaded_weight = loaded_weight.t().contiguous() if (
+            self.quant_method.__class__.__name__
+            == "CompressedTensorsMoEMethod") else loaded_weight
+
         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                              f"got {shard_id}.")
@@ -321,21 +339,43 @@ class FusedMoE(torch.nn.Module):
         expert_data = param.data[expert_id]
         tp_rank = get_tensor_model_parallel_rank()
 
-        # is_transposed: whether or not the parameter is transposed on disk
-        # If transposed, the loaded weight will be transposed and the dim
-        # to shard the loaded weight will be flipped.
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            loaded_weight = loaded_weight.t().contiguous()
             shard_dim = ~shard_dim
 
-        # Case weight_scales
-        if "weight_scale" in weight_name:
-            # load the weight scaling based on the quantization scheme
-            # supported weight scales can be found in
+        # Case input scale: input_scale loading is only supported for fp8
+        if "input_scale" in weight_name:
+            if param.data[expert_id] != 1 and (param.data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param.data[expert_id]} "
+                    f"vs. {loaded_weight}")
+
+            self._load_single_value(param=param,
+                                    loaded_weight=loaded_weight,
+                                    expert_id=expert_id)
+            return
+
+        # Case g_idx
+        if "g_idx" in weight_name:
+            self._load_g_idx(shard_dim=0,
+                             shard_id=shard_id,
+                             loaded_weight=loaded_weight,
+                             expert_data=expert_data,
+                             tp_rank=tp_rank)
+            return
+
+        # Case weight scales and zero_points
+        if ("scale" in weight_name or "zero" in weight_name):
+            # load the weight scales and zp based on the quantization scheme
+            # supported weight scales/zp can be found in
             # FusedMoeWeightScaleSupported
-            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
+            # TODO @dsikka: once hardened, refactor to use Aphrodite Parameters
             # specific to each case
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
@@ -362,22 +402,9 @@ class FusedMoE(torch.nn.Module):
                     f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
             return
 
+        # Case weight_shape
         if "weight_shape" in weight_name:
-            self._load_single_value(param=param,
-                                    loaded_weight=loaded_weight,
-                                    expert_id=expert_id)
-            return
-
-        # Case input scale
-        if "input_scale" in weight_name:
-            # Note: input_scale loading is only supported for fp8
-            if param.data[expert_id] != 1 and (param.data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param.data[expert_id]} "
-                    f"vs. {loaded_weight}")
-
+            # only required by compressed-tensors
             self._load_single_value(param=param,
                                     loaded_weight=loaded_weight,
                                     expert_id=expert_id)

+ 2 - 0
aphrodite/modeling/models/mixtral_quant.py

@@ -346,6 +346,8 @@ class MixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 

+ 1 - 0
aphrodite/quantization/__init__.py

@@ -47,6 +47,7 @@ QUANTIZATION_METHODS = {
     "quip": QuipConfig,
     "squeezellm": SqueezeLLMConfig,
     "compressed-tensors": CompressedTensorsConfig,
+    "compressed_tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "hqq": HQQMarlinConfig,

+ 28 - 16
aphrodite/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -6,7 +6,7 @@ import torch
 from compressed_tensors import CompressionFormat
 
 from aphrodite import _custom_ops as ops
-from aphrodite.modeling.layers.fused_moe import FusedMoEMethodBase
+from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
@@ -32,7 +32,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         config = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.num_bits = config.num_bits
         self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy.value
+        self.strategy = config.strategy
         self.group_size = config.group_size
         assert config.symmetric, (
             "Only symmetric quantization is supported for MoE")
@@ -268,19 +268,31 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
 
-        from aphrodite.modeling.layers.fused_moe.fused_moe import (
+        from aphrodite.modeling.layers.fused_moe.fused_marlin_moe import (
             fused_marlin_moe)
 
-        return fused_marlin_moe(x,
-                                layer.w13_weight_packed,
-                                layer.w2_weight_packed,
-                                router_logits,
-                                layer.w13_g_idx,
-                                layer.w2_g_idx,
-                                layer.w13_g_idx_sort_indices,
-                                layer.w2_g_idx_sort_indices,
-                                top_k,
-                                custom_routing_function,
-                                renormalize=renormalize,
-                                w1_scale=layer.w13_weight_scale,
-                                w2_scale=layer.w2_weight_scale)
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
+
+        return fused_marlin_moe(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            router_logits,
+            layer.w13_g_idx,
+            layer.w2_g_idx,
+            layer.w13_g_idx_sort_indices,
+            layer.w2_g_idx_sort_indices,
+            topk_weights,
+            topk_ids,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            num_bits=self.num_bits,
+        )

+ 20 - 4
aphrodite/quantization/utils/marlin_utils.py

@@ -67,6 +67,7 @@ def _check_marlin_supported(
         return (False, f"Marlin does not support group_size = {group_size}. "
                 f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
                 "are supported.")
+
     return True, None
 
 
@@ -193,6 +194,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
     return s
 
 
+def marlin_moe_permute_scales(
+    s: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    group_size: int,
+):
+    num_experts = s.shape[0]
+    output = torch.empty(
+        (num_experts, s.shape[1], s.shape[2]),
+        device=s.device,
+        dtype=s.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
+    return output
+
+
 def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                        num_bits: int) -> torch.Tensor:
     # Permute zero-points in a similar way to scales, but do not use the
@@ -268,8 +286,7 @@ def apply_gptq_marlin_linear(
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
                                   has_zp=False,
-                                  use_fp32_reduce=use_fp32_reduce,
-                                  is_zp_float=False)
+                                  use_fp32_reduce=use_fp32_reduce)
 
     if bias is not None:
         output.add_(bias)  # In-place add
@@ -306,8 +323,7 @@ def apply_awq_marlin_linear(
                                   size_k=input_size_per_partition,
                                   is_k_full=True,
                                   has_zp=True,
-                                  use_fp32_reduce=use_fp32_reduce,
-                                  is_zp_float=True)
+                                  use_fp32_reduce=use_fp32_reduce)
 
     if bias is not None:
         output.add_(bias)  # In-place add

+ 0 - 1
aphrodite/quantization/utils/marlin_utils_fp8.py

@@ -76,7 +76,6 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
 
     # WEIGHT SCALES
     scales = layer.weight_scale.to(layer.orig_dtype)
-
     # Permute scales
     marlin_scales = marlin_permute_scales(s=scales,
                                           size_k=part_size_k,

+ 7 - 4
aphrodite/quantization/utils/marlin_utils_test.py

@@ -1,6 +1,6 @@
 """Utility functions used for tests and benchmarks"""
 
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
     return perm
 
 
-def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
-                    act_order: bool):
+def marlin_quantize(w: torch.Tensor,
+                    quant_type: ScalarType,
+                    group_size: int,
+                    act_order: bool,
+                    test_perm: Optional[torch.Tensor] = None):
     size_k, size_n = w.shape
     num_bits = quant_type.size_bits
 
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
 
     # Quantize (and apply act_order if provided)
     w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
-        w, quant_type, group_size, act_order)
+        w, quant_type, group_size, act_order, test_perm)
 
     # For act_order, sort the "weights" and "g_idx" so that group ids are
     # increasing

+ 15 - 7
aphrodite/quantization/utils/quant_utils.py

@@ -1,5 +1,5 @@
 """This file is used for /tests and /benchmarks"""
-from typing import List
+from typing import List, Optional
 
 import numpy
 import torch
@@ -10,7 +10,7 @@ from aphrodite.scalar_type import ScalarType, scalar_types
 SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
-# NOTE: this is a hack. We should update each model to register the
+# Note: this is a hack. We should update each model to register the
 # stacked params and get it from there instead in a future PR.
 # fused_name: List[shard_name]
 FUSED_LAYER_NAME_MAPPING = {
@@ -95,7 +95,10 @@ def get_pack_factor(num_bits):
     return 32 // num_bits
 
 
-def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
+def permute_rows(q_w: torch.Tensor,
+                 w_ref: torch.Tensor,
+                 group_size: int,
+                 test_perm: Optional[torch.Tensor] = None):
     assert q_w.shape == w_ref.shape
 
     orig_device = q_w.device
@@ -106,7 +109,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
         g_idx[i] = i // group_size
 
     # Simulate act_order by doing a random permutation on K
-    rand_perm = torch.randperm(k_size)
+    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
 
     g_idx = g_idx[rand_perm].contiguous()
     q_w = q_w[rand_perm, :].contiguous()
@@ -168,6 +171,7 @@ def quantize_weights(w: torch.Tensor,
     w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
     w_q = torch.clamp(w_q, min_q_val, max_q_val)
 
+    # Compute ref (dequantized)
     # For some kernels (namely Machete) the zero-points are applied after the
     # scales are applied, for this case computing the reference in similar way
     # allows us to use tighter error tolerances in our unit tests.
@@ -205,8 +209,11 @@ def quantize_weights(w: torch.Tensor,
     )
 
 
-def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
-                          group_size: int, act_order: bool):
+def gptq_quantize_weights(w: torch.Tensor,
+                          quant_type: ScalarType,
+                          group_size: int,
+                          act_order: bool,
+                          test_perm: Optional[torch.Tensor] = None):
     size_k, _ = w.shape
 
     assert w.is_floating_point(), "w must be float"
@@ -227,7 +234,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
         ), "For act_order, groupsize = {} must be less than size_k = {}".format(
             group_size, size_k)
 
-        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
+        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
+                                                    test_perm)
 
     return w_ref, w_q, w_s, g_idx, rand_perm
 

+ 1425 - 0
kernels/moe/marlin_kernels/marlin_moe_kernel.h

@@ -0,0 +1,1425 @@
+#pragma once
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <iostream>
+
+#include "core/scalar_type.hpp"
+
+namespace marlin_moe {
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
+
+using I4 = Vec<int, 4>;
+
+// Matrix fragments for tensor core instructions; their precise layout is
+// documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>;  // quantization scales
+
+// Predicated asynchronous global->shared copy; used for inputs A where we apply
+// predication to handle batchsizes that are not multiples of 16.
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
+                                      bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template <int n>
+__device__ inline void cp_async_wait() {
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+}
+
+// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
+// output/accumulation.
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
+                           FragC& frag_c) {
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+               : "r"(smem));
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+template <aphrodite::ScalarTypeId w_type_id>
+__device__ inline FragB dequant(int q);
+
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+template <>
+__device__ inline FragB dequant<aphrodite::kU4B8.id()>(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+  // directly into `SUB` and `ADD`.
+  const int SUB = 0x64086408;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd480d480;
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16
+// Reference:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
+template <>
+__device__ inline FragB dequant<aphrodite::kU8B128.id()>(int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+// Given 2 floats multiply by 2 scales (halves)
+__device__ inline void scale_float(float* c, FragS& s) {
+  __half* s_ptr = reinterpret_cast<__half*>(&s);
+  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
+}
+
+// Same as above, but for act_order (each K is multiplied individually)
+__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
+                              FragS& frag_s_3, FragS& frag_s_4, int i) {
+  __half2 s_val_1_2;
+  s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i];
+  s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i];
+
+  __half2 s_val_3_4;
+  s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i];
+  s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i];
+
+  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
+  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                 :
+                 : "l"(lock), "r"(val));
+  }
+}
+
+template <const aphrodite::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
+          const int thread_m_blocks,  // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the
+                                      // threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const int stages,  // number of stages for the async global->shared
+                             // fetch pipeline
+          const bool has_act_order,    // whether act_order is enabled
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
+          >
+__device__ inline void MarlinMoESingle(
+    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
+    int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    const int* __restrict__ sorted_ids,      // int32 sorted ids of experts
+    const float* __restrict__ topk_weights,  // float topk weights
+    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
+                                          // (k/groupsize)xn
+    const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const int* __restrict__ expert_offsets,
+    int num_groups,        // number of scale groups per output channel
+    int expert_idx,        // idx of current expert
+    int num_experts,       // number of experts
+    int topk,              // topk parameter of moe
+    int prob_m,            // batch dimension m
+    int prob_n,            // output dimension n
+    int prob_k,            // reduction dimension k
+    int tot_m,             // total number of rows in A and C
+    int* locks,            // extra global storage for barrier synchronization
+    bool replicate_input,  // do we use the same input for each expert?
+    bool apply_weights,    // apply weights to output
+    int current_m_block    // current m block to start kernel computation from
+) {
+  static constexpr auto w_type = aphrodite::ScalarType::from_id(w_type_id);
+  constexpr int pack_factor = 32 / w_type.size_bits();
+
+  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
+  // better partitioning with less reductions
+  int parallel = 1;
+  if (prob_m > 16 * thread_m_blocks) {
+    parallel = prob_m / (16 * thread_m_blocks);
+    prob_m = 16 * thread_m_blocks;
+  }
+
+  int k_tiles = prob_k / 16 / thread_k_blocks;
+  int n_tiles = prob_n / 16 / thread_n_blocks;
+  int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
+
+  if constexpr (!has_act_order && group_blocks != -1) {
+    if (group_blocks >= thread_k_blocks) {
+      // Ensure that the number of tiles in each stripe is a multiple of the
+      // groupsize; this avoids an annoying special case where a stripe starts
+      // in the middle of group.
+      iters = (group_blocks / thread_k_blocks) *
+              ceildiv(iters, (group_blocks / thread_k_blocks));
+    }
+  }
+
+  int slice_row = (iters * blockIdx.x) % k_tiles;
+  int slice_col_par = (iters * blockIdx.x) / k_tiles;
+  int slice_col = slice_col_par;
+  int slice_iters;  // number of threadblock tiles in the current slice
+  int slice_count =
+      0;          // total number of active threadblocks in the current slice
+  int slice_idx;  // index of threadblock in current slice; numbered bottom to
+                  // top
+
+  // We can easily implement parallel problem execution by just remapping
+  // indices and advancing global pointers
+  if (slice_col_par >= n_tiles) {
+    locks += (slice_col_par / n_tiles) * n_tiles;
+    slice_col = slice_col_par % n_tiles;
+    sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
+  }
+
+  // Compute all information about the current slice which is required for
+  // synchronization.
+  auto init_slice = [&]() {
+    slice_iters =
+        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
+    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
+    if (slice_iters == 0) return;
+    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
+    slice_count = 1;
+    slice_idx = 0;
+    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
+    if (col_first <= k_tiles * (slice_col_par + 1)) {
+      int col_off = col_first - k_tiles * slice_col_par;
+      slice_count = ceildiv(k_tiles - col_off, iters);
+      if (col_off > 0) slice_count++;
+      int delta_first = iters * blockIdx.x - col_first;
+      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
+        slice_idx = slice_count - 1;
+      else {
+        slice_idx = slice_count - 1 - delta_first / iters;
+        if (col_off > 0) slice_idx--;
+      }
+    }
+    if (slice_col == n_tiles) {
+      sorted_ids += 16 * thread_m_blocks;
+      locks += n_tiles;
+      slice_col = 0;
+    }
+  };
+  init_slice();
+
+  // A sizes/strides
+
+  // stride of the A matrix in global memory
+  int a_gl_stride = prob_k / 8;
+  // stride of an A matrix tile in shared memory
+  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
+  // delta between subsequent A tiles in global memory
+  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
+  // between subsequent accesses within a tile
+  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
+  // between shared memory writes
+  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
+  // between shared memory tile reads
+  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
+  // within a shared memory tile
+  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
+  // overall size of a tile
+  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
+  // number of shared write iterations for a tile
+  constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);
+
+  // B sizes/strides
+  int b_gl_stride = 16 * prob_n / (pack_factor * 4);
+  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
+  constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
+  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
+
+  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
+  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
+  constexpr int b_sh_wr_delta = threads * b_thread_vecs;
+  constexpr int b_sh_rd_delta = threads * b_thread_vecs;
+  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
+  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
+
+  // Scale sizes/strides without act_order
+  int s_gl_stride = prob_n / 8;
+  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
+  constexpr int s_tb_groups =
+      !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
+          ? thread_k_blocks / group_blocks
+          : 1;
+  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
+  int s_gl_rd_delta = s_gl_stride;
+  // Scale size/strides with act_order
+  constexpr int tb_k = 16 * thread_k_blocks;
+  constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
+  // constexpr int act_s_row_stride      = 1;
+  // int           act_s_col_stride      = act_s_row_stride * num_groups;
+  int act_s_col_stride = 1;
+  int act_s_col_warp_stride = act_s_col_stride * 8;
+  int tb_n_warps = thread_n_blocks / 4;
+  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
+
+  constexpr int sorted_sh_stride = threads;
+  constexpr int sorted_gl_stride = threads;
+
+  // Global A read index of current thread.
+  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                (threadIdx.x % a_gl_rd_delta_o);
+  a_gl_rd += a_gl_rd_delta_o * slice_row;
+  // Shared write index of current thread.
+  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                (threadIdx.x % a_gl_rd_delta_o);
+  // Shared read index.
+  int a_sh_rd =
+      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
+  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
+
+  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
+                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
+  b_gl_rd += b_sh_stride * slice_col;
+  b_gl_rd += b_gl_rd_delta_o * slice_row;
+  int b_sh_wr = threadIdx.x * b_thread_vecs;
+  int b_sh_rd = threadIdx.x * b_thread_vecs;
+
+  // For act_order
+  constexpr int k_iter_size = tb_k / b_sh_wr_iters;
+  int slice_k_start = tb_k * slice_row;
+  int slice_k_finish = slice_k_start + tb_k * slice_iters;
+  int slice_k_start_shared_fetch = slice_k_start;
+  int slice_n_offset = act_s_col_tb_stride * slice_col;
+
+  // No act_order
+  int s_gl_rd;
+  if constexpr (!has_act_order) {
+    if constexpr (group_blocks == -1) {
+      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+    } else {
+      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+                s_sh_stride * slice_col + threadIdx.x;
+    }
+  }
+  int s_sh_wr = threadIdx.x;
+  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
+
+  // We use a different scale layout for grouped and column-wise quantization as
+  // we scale a `half2` tile in column-major layout in the former and in
+  // row-major in the latter case.
+  int s_sh_rd;
+  if constexpr (group_blocks != -1)
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+              (threadIdx.x % 32) / 4;
+  else
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+              (threadIdx.x % 32) % 4;
+
+  int sh_first_group_id = -1;
+  int sh_num_groups = -1;
+  constexpr int sh_max_num_groups = 32;
+
+  int shs_size;
+  if constexpr (has_act_order)
+    shs_size = sh_max_num_groups * s_sh_stride + threads;
+  else
+    shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
+
+  extern __shared__ int4 sh[];
+  // Shared memory storage for global fetch pipelines.
+  int4* sh_a = sh;
+  int4* sh_b = sh_a + (stages * a_sh_stage);
+  int4* sh_g_idx = sh_b + (stages * b_sh_stage);
+  int4* sh_s = sh_g_idx + (stages * g_idx_stage);
+  int* sh_sorted = (int*)(sh_s + shs_size);
+
+  // Precompute which thread should not read memory in which iterations; this is
+  // needed if there are more threads than required for a certain tilesize or
+  // when the batchsize is not a multiple of 16.
+  bool a_sh_wr_pred[a_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < a_sh_wr_iters; i++) {
+    int a_idx = a_sh_wr_delta * i + a_sh_wr;
+    int row = a_idx / a_gl_rd_delta_o;
+    if (row >= prob_m) {
+      a_sh_wr_pred[i] = false;
+    } else {
+      a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
+    }
+  }
+
+  // To ensure that writing and reading A tiles to/from shared memory, the
+  // latter in fragment format, is fully bank conflict free, we need to use a
+  // rather fancy XOR-based layout. The key here is that neither reads nor
+  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
+  // same shared memory banks. Further, it seems (based on NSight-Compute) that
+  // each warp must also write a consecutive memory segment?
+  auto transform_a = [&](int i) {
+    int row = i / a_gl_rd_delta_o;
+    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
+  };
+  // Since the computation of this remapping is non-trivial and, due to our main
+  // loop unrolls, all shared memory accesses are static, we simply precompute
+  // both transformed reads and writes.
+  int a_sh_wr_trans[a_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < a_sh_wr_iters; i++)
+    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
+  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
+  #pragma unroll
+  for (int i = 0; i < b_sh_wr_iters; i++) {
+  #pragma unroll
+    for (int j = 0; j < thread_m_blocks; j++)
+      a_sh_rd_trans[i][j] =
+          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
+  }
+
+  // Since B-accesses have non-constant stride they have to be computed at
+  // runtime; we break dependencies between subsequent accesses with a tile by
+  // maintining multiple pointers (we have enough registers), a tiny
+  // optimization.
+  const int4* B_ptr[b_sh_wr_iters];
+  #pragma unroll
+  for (int i = 0; i < b_sh_wr_iters; i++)
+    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
+
+  // Register storage for double buffer of shared memory reads.
+  FragA frag_a[2][thread_m_blocks];
+  I4 frag_b_quant[2][b_thread_vecs];
+  FragC frag_c[thread_m_blocks][4][2];
+  FragS frag_s[2][4];         // No act-order
+  FragS act_frag_s[2][4][4];  // For act-order
+
+  // Zero accumulators.
+  auto zero_accums = [&]() {
+  #pragma unroll
+    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
+      reinterpret_cast<float*>(frag_c)[i] = 0;
+  };
+
+  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
+                                    int last_group_id) {
+    sh_first_group_id = first_group_id;
+    sh_num_groups = last_group_id - first_group_id + 1;
+
+    if (sh_num_groups < sh_max_num_groups) {
+      sh_num_groups = sh_max_num_groups;
+    }
+
+    if (sh_first_group_id + sh_num_groups > num_groups) {
+      sh_num_groups = num_groups - sh_first_group_id;
+    }
+
+    int row_offset = first_group_id * s_gl_stride;
+
+    if (is_async) {
+      for (int i = 0; i < sh_num_groups; i++) {
+        if (threadIdx.x < s_sh_stride) {
+          cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
+                         &scales_ptr[row_offset + (i * s_gl_stride) +
+                                     slice_n_offset + threadIdx.x]);
+        }
+      }
+    } else {
+      for (int i = 0; i < sh_num_groups; i++) {
+        if (threadIdx.x < s_sh_stride) {
+          sh_s[(i * s_sh_stride) + threadIdx.x] =
+              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
+                         threadIdx.x];
+        }
+      }
+    }
+  };
+  // Asynchronously fetch the next A, B and s tile from global to the next
+  // shared memory pipeline location.
+  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
+    if (pred) {
+      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+  #pragma unroll
+      for (int i = 0; i < a_sh_wr_iters; i++) {
+        int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
+        int row = a_idx / a_gl_stride;
+        int sorted_row =
+            replicate_input ? sorted_ids[row] / topk : sorted_ids[row];
+        int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
+        if (sorted_row < tot_m * (replicate_input ? 1 : topk) &&
+            new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) {
+          cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx],
+                         a_sh_wr_pred[i]);
+        }
+      }
+      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+  #pragma unroll
+      for (int i = 0; i < b_sh_wr_iters; i++) {
+  #pragma unroll
+        for (int j = 0; j < b_thread_vecs; j++) {
+          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
+        }
+        B_ptr[i] += b_gl_rd_delta_o;
+      }
+
+      if constexpr (has_act_order) {
+        // Fetch g_idx thread-block portion
+        int full_pipe = a_off;
+        int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
+        if (cur_k < prob_k && cur_k < slice_k_finish) {
+          int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
+
+          int4 const* cur_g_idx_stage_ptr =
+              reinterpret_cast<int4 const*>(&g_idx[cur_k]);
+
+          if (threadIdx.x < g_idx_stage) {
+            cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
+                           &cur_g_idx_stage_ptr[threadIdx.x]);
+          }
+        }
+      } else {
+        if constexpr (group_blocks != -1) {
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+
+          if constexpr (group_blocks >= thread_k_blocks) {
+            // Only fetch scales if this tile starts a new group
+            if (pipe % (group_blocks / thread_k_blocks) == 0) {
+              if (s_sh_wr_pred) {
+                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+              }
+              s_gl_rd += s_gl_rd_delta;
+            }
+          } else {
+            for (int i = 0; i < s_tb_groups; i++) {
+              if (s_sh_wr_pred) {
+                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
+                          &scales_ptr[s_gl_rd]);
+              }
+              s_gl_rd += s_gl_rd_delta;
+            }
+          }
+        }
+      }
+    }
+    // Insert a fence even when we are winding down the pipeline to ensure that
+    // waiting is also correct at this point.
+    cp_async_fence();
+  };
+
+  // TODO we are currently hitting illegal memory accesses when fetching
+  // sorted_ids to shared data: fix this
+  auto fetch_sorted_ids_to_shared = [&]() {
+    const int mpt = ceildiv(prob_m, threads);
+    for (int i = 0; i < mpt; i++) {
+      if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
+        sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
+            sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
+      }
+    }
+  };
+
+  // Wait until the next thread tile has been loaded to shared memory.
+  auto wait_for_stage = [&]() {
+    // We only have `stages - 2` active fetches since we are double buffering
+    // and can only issue the next fetch when it is guaranteed that the previous
+    // shared memory load is fully complete (as it may otherwise be
+    // overwritten).
+    cp_async_wait<stages - 2>();
+    __syncthreads();
+  };
+
+  // Load the next sub-tile from the current location in the shared memory pipe
+  // into the current register buffer.
+  auto fetch_to_registers = [&](int k, int pipe) {
+    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+  #pragma unroll
+    for (int i = 0; i < thread_m_blocks; i++)
+      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+
+  #pragma unroll
+    for (int i = 0; i < b_thread_vecs; i++) {
+      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
+          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
+    }
+  };
+
+  bool is_same_group[stages];
+  int same_group_id[stages];
+
+  auto init_same_group = [&](int pipe) {
+    if constexpr (!has_act_order) {
+      is_same_group[pipe] = false;
+      same_group_id[pipe] = 0;
+      return;
+    }
+
+    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
+    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
+
+    int group_id_1 = sh_g_idx_int_ptr[0];
+    int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
+
+    is_same_group[pipe] = group_id_1 == group_id_2;
+    same_group_id[pipe] = group_id_1;
+  };
+
+  auto fetch_scales_to_registers = [&](int k, int full_pipe) {
+    int pipe = full_pipe % stages;
+
+    if constexpr (!has_act_order) {
+      // No act-order case
+      if constexpr (group_blocks != -1) {
+        if constexpr (group_blocks >= thread_k_blocks) {
+          int4* sh_s_stage =
+              sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
+                                   (pipe / (group_blocks / thread_k_blocks)));
+          reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
+        } else {
+          int warp_id = threadIdx.x / 32;
+          int n_warps = thread_n_blocks / 4;
+
+          int warp_row = warp_id / n_warps;
+
+          int cur_k = warp_row * 16;
+          cur_k += k_iter_size * (k % b_sh_wr_iters);
+
+          int k_blocks = cur_k / 16;
+          int cur_group_id = k_blocks / group_blocks;
+
+          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
+
+          reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
+              sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+        }
+      }
+
+      return;
+    }
+
+    // Act-order case
+
+    // Determine K of the "current" thread-block
+    int cur_k = slice_k_start + tb_k * full_pipe;
+    if (cur_k >= prob_k || cur_k >= slice_k_finish) {
+      return;
+    }
+
+    // Reset (to current thread-block) since we read g_idx portion from the
+    // shared memory
+    cur_k = 0;
+
+    // Progress to current iteration
+    cur_k += k_iter_size * (k % b_sh_wr_iters);
+
+    // Determine "position" inside the thread-block (based on warp and
+    // thread-id)
+    int warp_id = threadIdx.x / 32;
+    int n_warps =
+        thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N
+
+    int warp_row = warp_id / n_warps;
+    int warp_col = warp_id % n_warps;
+
+    cur_k += warp_row * 16;
+
+    int th_id = threadIdx.x % 32;
+    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix
+
+    int s_col_shift =
+        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
+        (th_id / 4) * act_s_col_stride;
+
+    if (is_same_group[pipe]) {
+      if (k % 2 == 0) {
+        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
+            sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
+                 s_col_shift];
+      } else {
+        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
+            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
+      }
+
+      for (int i = 1; i < 4; i++) {
+        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
+            *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
+      }
+      return;
+    }
+
+    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
+    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
+
+    constexpr int k_frag_offsets[4] = {0, 1, 8,
+                                       9};  // Tensor core offsets per thread
+
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      int actual_k = cur_k + k_frag_offsets[i];
+
+      int group_id = sh_g_idx_int_ptr[actual_k];
+      int rel_group_id = group_id - sh_first_group_id;
+
+      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
+          sh_s[rel_group_id * s_sh_stride + s_col_shift];
+    }
+  };
+
+  // Execute the actual tensor core matmul of a sub-tile.
+  auto matmul = [&](int k) {
+  // We have the m dimension as the inner loop in order to encourage overlapping
+  // dequantization and matmul operations.
+  #pragma unroll
+    for (int j = 0; j < 4; j++) {
+      int b_quant_0, b_quant_1;
+      if constexpr (w_type.size_bits() == 4) {
+        b_quant_0 = frag_b_quant[k % 2][0][j];
+        b_quant_1 = b_quant_0 >> 8;
+      } else {
+        static_assert(w_type.size_bits() == 8);
+        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
+        b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+        b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
+      }
+
+      FragB frag_b0 = dequant<w_type_id>(b_quant_0);
+      FragB frag_b1 = dequant<w_type_id>(b_quant_1);
+
+      // Apply scale to frag_b0
+      if constexpr (has_act_order) {
+        scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+               act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
+      } else {
+        if constexpr (group_blocks != -1) {
+          scale(frag_b0, frag_s[k % 2][j], 0);
+        }
+      }
+
+      // Apply scale to frag_b1
+      if constexpr (has_act_order) {
+        scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+               act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
+
+      } else {
+        if constexpr (group_blocks != -1) {
+          scale(frag_b1, frag_s[k % 2][j], 1);
+        }
+      }
+
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+      }
+    }
+  };
+
+  // Since we slice across the k dimension of a tile in order to increase the
+  // number of warps while keeping the n dimension of a tile reasonable, we have
+  // multiple warps that accumulate their partial sums of the same output
+  // location; which we have to reduce over in the end. We do in shared memory.
+  auto thread_block_reduce = [&]() {
+    constexpr int red_off = threads / b_sh_stride_threads / 2;
+    if (red_off >= 1) {
+      int red_idx = threadIdx.x / b_sh_stride_threads;
+      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride_threads;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
+                      (threadIdx.x % b_sh_stride_threads);
+
+      // Parallel logarithmic shared memory reduction. We make sure to avoid any
+      // unnecessary read or write iterations, e.g., for two warps we write only
+      // once by warp 1 and read only once by warp 0.
+
+  #pragma unroll
+      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+  #pragma unroll
+        for (int i = red_off; i > 0; i /= 2) {
+          if (i <= red_idx && red_idx < 2 * i) {
+  #pragma unroll
+            for (int j = 0; j < 4 * 2; j++) {
+              int red_sh_wr =
+                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+              if (i < red_off) {
+                float* c_rd =
+                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+  #pragma unroll
+                for (int k = 0; k < 4; k++)
+                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
+                      c_rd[k] + c_wr[k];
+              }
+              sh[red_sh_wr] =
+                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+            }
+          }
+          __syncthreads();
+        }
+        if (red_idx == 0) {
+  #pragma unroll
+          for (int i = 0; i < 4 * 2; i++) {
+            float* c_rd =
+                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+  #pragma unroll
+            for (int j = 0; j < 4; j++)
+              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
+                  c_rd[j];
+          }
+        }
+        __syncthreads();
+      }
+    }
+  };
+
+  // Since multiple threadblocks may process parts of the same column slice, we
+  // finally have to globally reduce over the results. As the striped
+  // partitioning minimizes the number of such reductions and our outputs are
+  // usually rather small, we perform this reduction serially in L2 cache.
+  auto global_reduce = [&](bool first = false, bool last = false) {
+    // We are very careful here to reduce directly in the output buffer to
+    // maximize L2 cache utilization in this step. To do this, we write out
+    // results in FP16 (but still reduce with FP32 compute).
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    if (threadIdx.x < active_threads) {
+      int c_gl_stride = prob_n / 8;
+      int c_gl_wr_delta_o = 8 * c_gl_stride;
+      int c_gl_wr_delta_i = 4 * (active_threads / 32);
+      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
+                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
+      c_gl_wr += (2 * thread_n_blocks) * slice_col;
+      constexpr int c_sh_wr_delta = active_threads;
+      int c_sh_wr = threadIdx.x;
+
+      int row = (threadIdx.x % 32) / 4;
+
+      if (!first) {
+  // Interestingly, doing direct global accesses here really seems to mess up
+  // the compiler and lead to slowdowns, hence we also use async-copies even
+  // though these fetches are not actually asynchronous.
+  #pragma unroll
+        for (int i = 0; i < thread_m_blocks * 4; i++) {
+          int c_idx =
+              c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
+          int sorted_row = sorted_ids[c_idx / c_gl_stride];
+          int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
+          cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx],
+                         sorted_row < tot_m * topk &&
+                             (8 * (i / 2) + row < prob_m &&
+                              (i < (thread_m_blocks - 1) * 4 ||
+                               sorted_ids[8 * (i / 2) + row] < tot_m * topk)));
+        }
+        cp_async_fence();
+        cp_async_wait<0>();
+      }
+
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks * 4; i++) {
+        if (8 * (i / 2) + row < prob_m &&
+            (i < (thread_m_blocks - 1) * 4 ||
+             sorted_ids[8 * (i / 2) + row] < tot_m * topk)) {
+          if (!first) {
+            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+  #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<float*>(
+                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
+                  __half2float(reinterpret_cast<__half*>(&c_red)[j]);
+            }
+          }
+          if (!last) {
+            int4 c;
+  #pragma unroll
+            for (int j = 0; j < 2 * 4; j++) {
+              reinterpret_cast<__half*>(&c)[j] =
+                  __float2half(reinterpret_cast<float*>(
+                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
+            }
+            int c_idx =
+                c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
+            int row = sorted_ids[c_idx / c_gl_stride];
+            if (row < tot_m * topk) {
+              int new_idx = row * c_gl_stride + c_idx % c_gl_stride;
+              C[new_idx] = c;
+            }
+          }
+        }
+      }
+    }
+  };
+
+  // Write out the reduce final result in the correct layout. We only actually
+  // reshuffle matrix fragments in this step, the reduction above is performed
+  // in fragment layout.
+  auto write_result = [&]() {
+    int c_gl_stride = prob_n / 8;
+    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+    constexpr int c_sh_rd_delta =
+        c_sh_stride * (threads / (2 * thread_n_blocks));
+
+    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+                  (threadIdx.x % (2 * thread_n_blocks));
+    c_gl_wr += (2 * thread_n_blocks) * slice_col;
+    int c_sh_wr =
+        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+    c_sh_wr += 32 * (threadIdx.x / 32);
+    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+                  (threadIdx.x % (2 * thread_n_blocks));
+
+    int c_gl_wr_end = c_gl_stride * prob_m;
+
+    // We first reorder in shared memory to guarantee the most efficient final
+    // global write patterns
+    auto write = [&](int idx, float c0, float c1, FragS& s) {
+      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+
+      // For per-column quantization we finally apply the scale here (only for
+      // 4-bit)
+      if constexpr (!has_act_order && group_blocks == -1 &&
+                    w_type.size_bits() == 4) {
+        res = __hmul2(res, s[0]);
+      }
+
+      ((half2*)sh)[idx] = res;
+    };
+    if (threadIdx.x / 32 < thread_n_blocks / 4) {
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+  #pragma unroll
+        for (int j = 0; j < 4; j++) {
+          int wr = c_sh_wr + 8 * j;
+          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
+                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
+                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
+                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
+                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+        }
+        c_sh_wr += 16 * (4 * c_sh_stride);
+      }
+    }
+    __syncthreads();
+
+  #pragma unroll
+    for (int i = 0;
+         i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
+         i++) {
+      if (c_gl_wr < c_gl_wr_end) {
+        int row = sorted_ids[c_gl_wr / c_gl_stride];
+        if (row < tot_m * topk) {
+          int off = row * c_gl_stride + c_gl_wr % c_gl_stride;
+          if (!apply_weights) {
+            C[off] = sh[c_sh_rd];
+          } else {
+            __half* ctrg = reinterpret_cast<__half*>(&C[off]);
+            __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]);
+            for (int j = 0; j < 8; ++j) {
+              ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j]));
+            }
+          }
+          c_gl_wr += c_gl_wr_delta;
+          c_sh_rd += c_sh_rd_delta;
+        }
+      }
+    }
+  };
+
+  // Start global fetch and register load pipelines.
+  auto start_pipes = [&]() {
+  // TODO re-enable after fixing this function
+  // fetch_sorted_ids_to_shared();
+  // __syncthreads();
+
+  #pragma unroll
+    for (int i = 0; i < stages - 1; i++) {
+      if (has_act_order && i == 0) {
+        int last_g_idx = slice_k_start + stages * tb_k * 2;
+        if (last_g_idx >= prob_k) {
+          last_g_idx = prob_k - 1;
+        }
+        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
+      }
+      fetch_to_shared(i, i, i < slice_iters);
+    }
+
+    zero_accums();
+    wait_for_stage();
+    init_same_group(0);
+    fetch_to_registers(0, 0);
+    fetch_scales_to_registers(0, 0);
+    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+    slice_k_start_shared_fetch += tb_k * (stages - 1);
+  };
+  if (slice_iters) {
+    start_pipes();
+  }
+
+  // Main loop.
+  while (slice_iters) {
+    // We unroll over both the global fetch and the register load pipeline to
+    // ensure all shared memory accesses are static. Note that both pipelines
+    // have even length meaning that the next iteration will always start at
+    // index 0.
+  #pragma unroll
+    for (int pipe = 0; pipe < stages;) {
+  #pragma unroll
+      for (int k = 0; k < b_sh_wr_iters; k++) {
+        fetch_to_registers(k + 1, pipe % stages);
+        fetch_scales_to_registers(k + 1, pipe);
+        if (k == b_sh_wr_iters - 2) {
+          fetch_to_shared((pipe + stages - 1) % stages, pipe,
+                          slice_iters >= stages);
+          pipe++;
+          wait_for_stage();
+          init_same_group(pipe % stages);
+        }
+        matmul(k);
+      }
+      slice_iters--;
+      if (slice_iters == 0) {
+        break;
+      }
+    }
+
+    a_gl_rd += a_gl_rd_delta_o * stages;
+    slice_k_start += tb_k * stages;
+    slice_k_start_shared_fetch += tb_k * stages;
+
+    if constexpr (has_act_order) {
+      int first_group_id = g_idx[slice_k_start];
+      int last_g_idx = slice_k_start + stages * tb_k * 2;
+      if (last_g_idx >= prob_k) {
+        last_g_idx = prob_k - 1;
+      }
+      int last_group_id = g_idx[last_g_idx];
+      if (last_group_id >= sh_first_group_id + sh_num_groups) {
+        fetch_scales_to_shared(false, first_group_id, last_group_id);
+        __syncthreads();
+      }
+    }
+
+    // Process results and, if necessary, proceed to the next column slice.
+    // While this pattern may not be the most readable, other ways of writing
+    // the loop seemed to noticeably worse performance after compilation.
+    if (slice_iters == 0) {
+      cp_async_wait<0>();
+      bool last = slice_idx == slice_count - 1;
+      if constexpr (!has_act_order && group_blocks == -1) {
+        if constexpr (w_type.size_bits() == 8) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
+          cp_async_fence();
+        } else {
+          // For 4-bit per-column scales, we only fetch them here in the
+          // final step before write-out
+          if (last) {
+            if (s_sh_wr_pred) {
+              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+            }
+            cp_async_fence();
+          }
+        }
+      }
+
+      thread_block_reduce();
+      if constexpr (!has_act_order && group_blocks == -1) {
+        if constexpr (w_type.size_bits() == 8) {
+          cp_async_wait<0>();
+          __syncthreads();
+          if (threadIdx.x / 32 < thread_n_blocks / 4) {
+            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+          }
+
+        } else {
+          if (last) {
+            cp_async_wait<0>();
+            __syncthreads();
+            if (threadIdx.x / 32 < thread_n_blocks / 4) {
+              reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+              reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+            }
+          }
+        }
+      }
+
+      // For 8-bit channelwise, we apply the scale before the global reduction
+      // that converts the fp32 results to fp16 (so that we avoid possible
+      // overflow in fp16)
+      if constexpr (!has_act_order && group_blocks == -1 &&
+                    w_type.size_bits() == 8) {
+        if (threadIdx.x / 32 < thread_n_blocks / 4) {
+  #pragma unroll
+          for (int i = 0; i < thread_m_blocks; i++) {
+  #pragma unroll
+            for (int j = 0; j < 4; j++) {
+              scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
+                          frag_s[j / 2][2 * (j % 2) + 0]);
+              scale_float(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
+                          frag_s[j / 2][2 * (j % 2) + 0]);
+
+              scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
+                          frag_s[j / 2][2 * (j % 2) + 1]);
+              scale_float(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
+                          frag_s[j / 2][2 * (j % 2) + 1]);
+            }
+          }
+        }
+      }
+
+      if (slice_count > 1) {  // only globally reduce if there is more than one
+                              // block in a slice
+        barrier_acquire(&locks[slice_col], slice_idx);
+        global_reduce(slice_idx == 0, last);
+        barrier_release(&locks[slice_col], last);
+      }
+      if (last)  // only the last block in a slice actually writes the result
+        write_result();
+      slice_row = 0;
+      slice_col_par++;
+      slice_col++;
+      init_slice();
+      if (slice_iters) {
+        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+                  (threadIdx.x % a_gl_rd_delta_o);
+  #pragma unroll
+        for (int i = 0; i < b_sh_wr_iters; i++)
+          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+        if (slice_col == 0) {
+  #pragma unroll
+          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
+        }
+
+        // Update slice k/n for scales loading
+        if constexpr (has_act_order) {
+          slice_k_start = tb_k * slice_row;
+          slice_k_finish = slice_k_start + tb_k * slice_iters;
+          slice_k_start_shared_fetch = slice_k_start;
+          slice_n_offset = act_s_col_tb_stride * slice_col;
+
+        } else {
+          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+        }
+        start_pipes();
+      }
+    }
+  }
+}
+
+template <const aphrodite::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const int stages,  // number of stages for the async global->shared
+                             // fetch pipeline
+          const bool has_act_order,    // whether act_order is enabled
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
+          >
+__global__ void MarlinMoE(
+    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
+    int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    const int* __restrict__ sorted_ids_base,  // int32 sorted ids of experts
+    const float* __restrict__ topk_weights,   // float topk weights
+    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
+                                          // (k/groupsize)xn
+    const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const int* __restrict__ expert_offsets,
+    int num_groups,        // number of scale groups per output channel
+    int expert_idx,        // idx of current expert
+    int num_experts,       // number of experts
+    int topk,              // topk parameter of moe
+    int prob_m,            // batch dimension m
+    int prob_n,            // output dimension n
+    int prob_k,            // reduction dimension k
+    int tot_m,             // total number of rows in A and C
+    int* locks,            // extra global storage for barrier synchronization
+    bool replicate_input,  // do we use the same input for each expert?
+    bool apply_weights,    // apply weights to output
+    int current_m_block,   // current m block to start kernel computation from
+    int max_par,           // maximum parallelism
+    int cfg_max_m_blocks   // upper bound on m blocks
+) {
+  int m_block_ctr = current_m_block;
+
+  const int* sorted_ids_expert =
+      sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par;
+  int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx];
+  if (tot_its == 0) {
+    return;
+  }
+  int tot_m_blocks = ceildiv(tot_its, 16);
+  int pad = 16 * tot_m_blocks - tot_its;
+
+  if (m_block_ctr >= tot_m_blocks) {
+    return;
+  }
+
+  int max_block = tot_m_blocks - m_block_ctr;
+  prob_m = tot_its - 16 * m_block_ctr;
+
+  int par = 1;
+  if (max_block > cfg_max_m_blocks) {
+    // Note that parallel > 1 currently only works for inputs without any
+    // padding
+    par = (16 * max_block - pad) / (16 * cfg_max_m_blocks);
+    if (par > max_par) par = max_par;
+    prob_m = (16 * cfg_max_m_blocks) * par;
+    m_block_ctr += cfg_max_m_blocks * (par - 1);
+    max_block = cfg_max_m_blocks;
+  }
+
+  if (max_block == 1) {
+    MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
+                    stages, has_act_order, group_blocks>(
+        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
+        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+        current_m_block);
+  } else if (max_block == 2) {
+    MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
+                    stages, has_act_order, group_blocks>(
+        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
+        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+        current_m_block);
+  } else if (max_block == 3) {
+    MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
+                    stages, has_act_order, group_blocks>(
+        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
+        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+        current_m_block);
+  } else {
+    MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
+                    stages, has_act_order, group_blocks>(
+        A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
+        expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
+        prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
+        current_m_block);
+  }
+}
+
+#else
+
+template <const aphrodite::ScalarTypeId w_type_id,  // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const int stages,  // number of stages for the async global->shared
+                             // fetch pipeline
+          const bool has_act_order,    // whether act_order is enabled
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
+          >
+__global__ void MarlinMoE(
+    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
+    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
+    int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    const int* __restrict__ sorted_ids,      // int32 sorted ids of experts
+    const float* __restrict__ topk_weights,  // float topk weights
+    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
+                                          // (k/groupsize)xn
+    const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const int* __restrict__ expert_offsets,
+    int num_groups,        // number of scale groups per output channel
+    int expert_idx,        // idx of current expert
+    int num_experts,       // number of experts
+    int topk,              // topk parameter of moe
+    int prob_m,            // batch dimension m
+    int prob_n,            // output dimension n
+    int prob_k,            // reduction dimension k
+    int tot_m,             // total number of rows in A and C
+    int* locks,            // extra global storage for barrier synchronization
+    bool replicate_input,  // do we use the same input for each expert?
+    bool apply_weights,    // apply weights to output
+    int current_m_block,   // current m block to start kernel computation from
+    int max_par,           // maximum parallelism
+    int cfg_max_m_blocks   // upper bound on m blocks
+
+) {
+  // Marlin is not implemented yet for SM < 8.0
+  assert(false);
+  return;
+}
+
+#endif
+
+// 8 warps are a good choice since every SM has 4 schedulers and having more
+// than 1 warp per schedule allows some more latency hiding. At the same time,
+// we want relatively few warps to have many registers per warp and small tiles.
+const int USER_THREADS =
+    256;               // Note: This is only used with user-provided thread_k/n
+const int STAGES = 4;  // 4 pipeline stages fit into shared memory
+// const int SHARED_MEM =
+//     96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
+
+static constexpr int min_thread_n = 64;
+static constexpr int min_thread_k = 64;
+
+#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
+                      GROUP_BLOCKS, NUM_THREADS)                               \
+  else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS &&           \
+           thread_k_blocks == THREAD_K_BLOCKS &&                               \
+           has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS &&   \
+           num_threads == NUM_THREADS) {                                       \
+    cudaFuncSetAttribute(                                                      \
+        MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,  \
+                  STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>,                        \
+        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);          \
+    MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,      \
+              STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>                             \
+        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
+            A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,      \
+            g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx,             \
+            num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,           \
+            replicate_input, apply_weights, m_block, max_par,                  \
+            cfg_max_m_blocks);                                                 \
+  }
+
+#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \
+  __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
+
+}  // namespace marlin_moe

+ 29 - 0
kernels/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu

@@ -0,0 +1,29 @@
+#include "marlin_moe_kernel_ku4b8.h"
+
+namespace marlin_moe {
+
+// We return bool so we can create these different kernel calls as a sequence
+// of if-elseif's.
+bool call_marlin_moe_kernel_ku4b8(
+    aphrodite::ScalarType const& q_type, int thread_n_blocks,
+    int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads,
+    int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
+    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
+    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
+    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
+    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
+    bool replicate_input, bool apply_weights, int m_block, int max_par,
+    int cfg_max_m_blocks) {
+  if (false) {
+  }
+  GPTQ_CALL_IF_MOE(aphrodite::kU4B8, 16, 4, 256)
+  GPTQ_CALL_IF_MOE(aphrodite::kU4B8, 8, 8, 256)
+  GPTQ_CALL_IF_MOE(aphrodite::kU4B8, 8, 4, 128)
+  GPTQ_CALL_IF_MOE(aphrodite::kU4B8, 4, 8, 128)
+  else {
+    return false;
+  }
+  return true;
+}
+
+}  // namespace marlin_moe

+ 20 - 0
kernels/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h

@@ -0,0 +1,20 @@
+#pragma once
+
+#include "marlin_moe_kernel.h"
+
+namespace marlin_moe {
+
+// We return bool so we can create these different kernel calls as a sequence
+// of if-elseif's.
+bool call_marlin_moe_kernel_ku4b8(
+    aphrodite::ScalarType const& q_type, int thread_n_blocks,
+    int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads,
+    int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
+    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
+    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
+    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
+    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
+    bool replicate_input, bool apply_weights, int m_block, int max_par,
+    int cfg_max_m_blocks);
+
+}  // namespace marlin_moe

+ 29 - 0
kernels/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu

@@ -0,0 +1,29 @@
+#include "marlin_moe_kernel_ku8b128.h"
+
+namespace marlin_moe {
+
+// We return bool so we can create these different kernel calls as a sequence
+// of if-elseif's.
+bool call_marlin_moe_kernel_ku8b128(
+    aphrodite::ScalarType const& q_type, int thread_n_blocks,
+    int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads,
+    int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
+    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
+    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
+    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
+    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
+    bool replicate_input, bool apply_weights, int m_block, int max_par,
+    int cfg_max_m_blocks) {
+  if (false) {
+  }
+  GPTQ_CALL_IF_MOE(aphrodite::kU8B128, 16, 4, 256)
+  GPTQ_CALL_IF_MOE(aphrodite::kU8B128, 8, 8, 256)
+  GPTQ_CALL_IF_MOE(aphrodite::kU8B128, 8, 4, 128)
+  GPTQ_CALL_IF_MOE(aphrodite::kU8B128, 4, 8, 128)
+  else {
+    return false;
+  }
+  return true;
+}
+
+}  // namespace marlin_moe

+ 18 - 0
kernels/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h

@@ -0,0 +1,18 @@
+#pragma once
+
+#include "marlin_moe_kernel.h"
+
+namespace marlin_moe {
+
+bool call_marlin_moe_kernel_ku8b128(
+    aphrodite::ScalarType const& q_type, int thread_n_blocks,
+    int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads,
+    int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
+    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
+    const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
+    int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
+    int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
+    bool replicate_input, bool apply_weights, int m_block, int max_par,
+    int cfg_max_m_blocks);
+
+}

File diff suppressed because it is too large
+ 27 - 1092
kernels/moe/marlin_moe_ops.cu


+ 5 - 2
kernels/moe/marlin_moe_ops.h

@@ -2,11 +2,14 @@
 
 #include <torch/all.h>
 
+#include "core/scalar_type.hpp"
+
 torch::Tensor marlin_gemm_moe(
     const torch::Tensor& a, const torch::Tensor& b_q_weights,
     const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
     const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
     const torch::Tensor& g_idx, const torch::Tensor& perm,
-    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
-    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
+    torch::Tensor& workspace, aphrodite::ScalarTypeTorchPtr const& b_q_type,
+    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
+    int64_t num_experts, int64_t topk, int64_t moe_block_size,
     bool replicate_input, bool apply_weights);

+ 2 - 3
kernels/moe/softmax.cu → kernels/moe/topk_softmax_kernels.cu

@@ -1,7 +1,6 @@
 /*
  * Adapted from
  * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
- * Copyright (c) 2024, The PygmalionAI team.
  * Copyright (c) 2024, The vLLM team.
  * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
  * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
@@ -253,8 +252,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 
   // Determine the pointer type to use to read in the data depending on the
   // BYTES_PER_LDG template param. In theory, this can support all powers of 2
-  // up to 16. NOTE: The original implementation uses CUTLASS aligned array
-  // here. We defined our own aligned array and use it here to avoid the
+  // up to 16. NOTE(woosuk): The original implementation uses CUTLASS aligned
+  // array here. We defined our own aligned array and use it here to avoid the
   // dependency on CUTLASS.
   using AccessType = AlignedArray<float, ELTS_PER_LDG>;
 

+ 5 - 4
kernels/moe/torch_bindings.cpp

@@ -17,10 +17,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
-      "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
-      "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
-      "bool replicate_input, bool apply_weights) -> Tensor");
-
+      "g_idx, Tensor! perm, Tensor! workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
+      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
+      "int moe_block_size, bool replicate_input, bool apply_weights)"
+      " -> Tensor");
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }

+ 91 - 225
kernels/quantization/gptq_marlin/gptq_marlin.cu

@@ -52,10 +52,9 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
           const int thread_k_blocks,  // same for k dimension (reduction)
           const int stages,  // number of stages for the async global->shared
                              // fetch pipeline
-          const bool has_act_order,     // whether act_order is enabled
-          const int group_blocks = -1,  // number of consecutive 16x16 blocks
-                                        // with a separate quantization scale
-          const bool is_zp_float        // is zero point of float16 type?
+          const bool has_act_order,    // whether act_order is enabled
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
           >
 __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
@@ -81,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp, bool is_zp_float) {
+                               bool is_k_full, bool has_zp) {
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
   return torch::empty({1, 1});
@@ -389,17 +388,6 @@ __device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
   frag_b[1] = __hsub2(frag_b[1], zp);
 }
 
-template <typename scalar_t>
-__device__ inline void sub_zpf(typename ScalarType<scalar_t>::FragB& frag_b,
-                               typename ScalarType<scalar_t>::FragZPF& frag_zpf,
-                               int i) {
-  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
-  scalar_t2 zp =
-      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zpf)[i]);
-  frag_b[0] = __hsub2(frag_b[0], zp);
-  frag_b[1] = __hsub2(frag_b[1], zp);
-}
-
 // Same as above, but for act_order (each K is multiplied individually)
 template <typename scalar_t>
 __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
@@ -430,15 +418,6 @@ __device__ inline void scale_float(float* c,
   c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
 }
 
-// Given 2 floats subtract by 2 zero points (halves)
-template <typename scalar_t>
-__device__ inline void sub_zpf_float(
-    float* c, typename ScalarType<scalar_t>::FragZPF& zp) {
-  scalar_t* zp_ptr = reinterpret_cast<scalar_t*>(&zp);
-  c[0] = __fsub_rn(c[0], ScalarType<scalar_t>::num2float(zp_ptr[0]));
-  c[1] = __fsub_rn(c[1], ScalarType<scalar_t>::num2float(zp_ptr[1]));
-}
-
 // Wait until barrier reaches `count`, then lock for current threadblock.
 __device__ inline void barrier_acquire(int* lock, int count) {
   if (threadIdx.x == 0) {
@@ -535,11 +514,10 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
           const int thread_k_blocks,  // same for k dimension (reduction)
           const int stages,  // number of stages for the async global->shared
                              // fetch pipeline
-          const bool has_act_order,     // whether act_order is enabled
-          const bool has_zp,            // whether zero-points are enabled
-          const int group_blocks = -1,  // number of consecutive 16x16 blocks
-                                        // with a separate quantization scale
-          const bool is_zp_float        // is zero point of float16 type?
+          const bool has_act_order,    // whether act_order is enabled
+          const bool has_zp,           // whether zero-points are enabled
+          const int group_blocks = -1  // number of consecutive 16x16 blocks
+                                       // with a separate quantization scale
           >
 __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
@@ -576,7 +554,6 @@ __global__ void Marlin(
   using FragC = typename ScalarType<scalar_t>::FragC;
   using FragS = typename ScalarType<scalar_t>::FragS;
   using FragZP = typename ScalarType<scalar_t>::FragZP;
-  using FragZPF = typename ScalarType<scalar_t>::FragZPF;
 
   static constexpr auto w_type = aphrodite::ScalarType::from_id(w_type_id);
 
@@ -713,10 +690,8 @@ __global__ void Marlin(
   int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
 
   // Zero-points sizes/strides
-  int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
-  constexpr int zp_sh_stride = is_zp_float
-                                   ? 16 * thread_n_blocks / 8
-                                   : ((16 * thread_n_blocks) / pack_factor) / 4;
+  int zp_gl_stride = (prob_n / pack_factor) / 4;
+  constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
   constexpr int zp_tb_groups = s_tb_groups;
   constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
   int zp_gl_rd_delta = zp_gl_stride;
@@ -791,15 +766,9 @@ __global__ void Marlin(
   constexpr int num_ints_per_thread = 8 / pack_factor;
   int zp_sh_rd;
   if constexpr (has_zp) {
-    if constexpr (is_zp_float) {
-      if constexpr (group_blocks != -1)
-        zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
-                   (threadIdx.x % 32) / 4;
-    } else {
-      zp_sh_rd = num_ints_per_thread * num_col_threads *
-                     ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
-                 num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
-    }
+    zp_sh_rd = num_ints_per_thread * num_col_threads *
+                   ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+               num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
   }
 
   // Precompute which thread should not read memory in which iterations; this is
@@ -861,7 +830,6 @@ __global__ void Marlin(
   FragS act_frag_s[2][4][4];             // For act-order
   int frag_qzp[2][num_ints_per_thread];  // Zero-points
   FragZP frag_zp;                        // Zero-points in fp16
-  FragZPF frag_zpf[2][4];                // float16 zero-points
 
   // Zero accumulators.
   auto zero_accums = [&]() {
@@ -1156,7 +1124,7 @@ __global__ void Marlin(
     // has_zp implies AWQ, which doesn't have act_order,
     static_assert(!has_zp || group_blocks != 0);
 
-    if constexpr (has_zp && !is_zp_float) {
+    if constexpr (has_zp) {
       int pipe = full_pipe % stages;
 
       if constexpr (group_blocks == -1) {
@@ -1200,40 +1168,11 @@ __global__ void Marlin(
         }
       }
     }
-
-    if constexpr (has_zp && is_zp_float) {
-      int pipe = full_pipe % stages;
-
-      if constexpr (group_blocks != -1) {
-        if constexpr (group_blocks >= thread_k_blocks) {
-          int4* sh_zp_stage =
-              sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
-                                     (pipe / (group_blocks / thread_k_blocks)));
-          reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
-        } else {
-          int warp_id = threadIdx.x / 32;
-          int n_warps = thread_n_blocks / 4;
-
-          int warp_row = warp_id / n_warps;
-
-          int cur_k = warp_row * 16;
-          cur_k += k_iter_size * (k % b_sh_wr_iters);
-
-          int k_blocks = cur_k / 16;
-          int cur_group_id = k_blocks / group_blocks;
-
-          int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
-
-          reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
-              sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
-        }
-      }
-    }
   };
 
   // Execute the actual tensor core matmul of a sub-tile.
   auto matmul = [&](int k) {
-    if constexpr (has_zp && !is_zp_float) {
+    if constexpr (has_zp) {
       FragB frag_zp_0;
       FragB frag_zp_1;
       int zp_quant_0, zp_quant_1;
@@ -1278,14 +1217,10 @@ __global__ void Marlin(
       frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
 
       // Apply zero-point to frag_b0
-      if constexpr (has_zp && !is_zp_float) {
+      if constexpr (has_zp) {
         sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
       }
 
-      if constexpr (has_zp && is_zp_float && group_blocks != -1) {
-        sub_zpf<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
-      }
-
       // Apply scale to frag_b0
       if constexpr (has_act_order) {
         scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
@@ -1298,14 +1233,10 @@ __global__ void Marlin(
       }
 
       // Apply zero-point to frag_b1
-      if constexpr (has_zp && !is_zp_float) {
+      if constexpr (has_zp) {
         sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
       }
 
-      if constexpr (has_zp && is_zp_float && group_blocks != -1) {
-        sub_zpf<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
-      }
-
       // Apply scale to frag_b1
       if constexpr (has_act_order) {
         scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
@@ -1577,7 +1508,7 @@ __global__ void Marlin(
         fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
       }
 
-      if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
+      if constexpr (has_zp && group_blocks == -1) {
         if (i == 0) {
           fetch_zp_to_shared();
         }
@@ -1764,27 +1695,23 @@ __global__ void Marlin(
 }
 
   #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
-                    HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS,          \
-                    IS_ZP_FLOAT)                                               \
-    if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&              \
-        thread_n_blocks == THREAD_N_BLOCKS &&                                  \
-        thread_k_blocks == THREAD_K_BLOCKS &&                                  \
-        has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&                  \
-        group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS &&          \
-        is_zp_float == IS_ZP_FLOAT) {                                          \
-      if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) {     \
-        cudaFuncSetAttribute(                                                  \
-            Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,        \
-                   THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages,              \
-                   HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>,          \
-            cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);      \
-        Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,            \
-               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,   \
-               HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                              \
-            <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                 \
-                A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,      \
-                num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);   \
-      }                                                                        \
+                    HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS)          \
+    else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&         \
+             thread_n_blocks == THREAD_N_BLOCKS &&                             \
+             thread_k_blocks == THREAD_K_BLOCKS &&                             \
+             has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&             \
+             group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {     \
+      cudaFuncSetAttribute(                                                    \
+          Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,          \
+                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
+                 HAS_ZP, GROUP_BLOCKS>,                                        \
+          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
+      Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,              \
+             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \
+             HAS_ZP, GROUP_BLOCKS>                                             \
+          <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                   \
+              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,        \
+              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);     \
     }
 
 typedef struct {
@@ -1976,96 +1903,51 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
 }
 
   #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS,   \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS,   \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS,   \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS,   \
-              false)                                                        \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS)   \
                                                                             \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS,  \
-              false)                                                        \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
                                                                             \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS,  \
-              false)                                                        \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
                                                                             \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS,  \
-              false)                                                        \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)  \
                                                                             \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS,  \
-              false)                                                        \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS,  \
-              false)
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
 
   #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS,  \
-              false)                                                       \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
                                                                            \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS,  \
-              false)                                                       \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
                                                                            \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS,  \
-              false)                                                       \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)  \
                                                                            \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS,  \
-              false)                                                       \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
-
-  // We currently have 4-bit models only with group_blocks == 4
-  #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)            \
-    __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
-              true)                                                       \
-    __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
-              true)                                                       \
-    __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
-              true)                                                       \
-    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS)  \
+    __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
 void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
@@ -2074,7 +1956,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                aphrodite::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_fp32_reduce) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == aphrodite::kU4 || q_type == aphrodite::kU8,
@@ -2227,15 +2109,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
     AWQ_CALL_IF(aphrodite::kU8, 8, 8, 256)
     AWQ_CALL_IF(aphrodite::kU8, 8, 4, 128)
     AWQ_CALL_IF(aphrodite::kU8, 4, 8, 128)
-
-    HQQ_CALL_IF(aphrodite::kU4, 16, 4, 256)
-    HQQ_CALL_IF(aphrodite::kU4, 8, 8, 256)
-    HQQ_CALL_IF(aphrodite::kU4, 8, 4, 128)
-    HQQ_CALL_IF(aphrodite::kU4, 4, 8, 128)
-    HQQ_CALL_IF(aphrodite::kU8, 16, 4, 256)
-    HQQ_CALL_IF(aphrodite::kU8, 8, 8, 256)
-    HQQ_CALL_IF(aphrodite::kU8, 8, 4, 128)
-    HQQ_CALL_IF(aphrodite::kU8, 4, 8, 128)
+    else {
+      TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
+                  ", ", prob_k, "]", ", has_act_order = ", has_act_order,
+                  ", num_groups = ", num_groups, ", group_size = ", group_size,
+                  ", thread_m_blocks = ", thread_m_blocks,
+                  ", thread_n_blocks = ", thread_n_blocks,
+                  ", thread_k_blocks = ", thread_k_blocks,
+                  ", num_bits = ", num_bits);
+    }
 
     A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
     C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
@@ -2251,7 +2133,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                int64_t size_m, int64_t size_n, int64_t size_k,
                                bool is_k_full, bool has_zp,
-                               bool use_fp32_reduce, bool is_zp_float) {
+                               bool use_fp32_reduce) {
   if (has_zp) {
     TORCH_CHECK(*b_q_type == aphrodite::kU4 || *b_q_type == aphrodite::kU8,
                 "b_q_type must be u4 or u8 when has_zp = True. Got = ",
@@ -2263,12 +2145,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_q_type->str());
   }
 
-  if (has_zp && is_zp_float) {
-    TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
-                "Computation type must be float16 (half) when using float zero "
-                "points.");
-  }
-
   int pack_factor = 32 / b_q_type->size_bits();
 
   // Verify A
@@ -2378,22 +2254,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (has_zp) {
     int rank = b_zeros.sizes().size();
     TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
-    if (is_zp_float) {
-      TORCH_CHECK(b_zeros.size(1) == size_n,
-                  "b_zeros dim 1 = ", b_zeros.size(1),
-                  " is not size_n = ", size_n);
-      TORCH_CHECK(num_groups == b_zeros.size(0),
-                  "b_zeros dim 0 = ", b_zeros.size(0),
-                  " is not num_groups = ", num_groups);
-      TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
-    } else {
-      TORCH_CHECK(b_zeros.size(0) == num_groups,
-                  "b_zeros dim 0 = ", b_zeros.size(0),
-                  " is not num_groups = ", num_groups);
-      TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
-                  "b_zeros dim 1 = ", b_zeros.size(1),
-                  " is not size_n / pack_factor = ", size_n / pack_factor);
-    }
+    TORCH_CHECK(b_zeros.size(0) == num_groups,
+                "b_zeros dim 0 = ", b_zeros.size(0),
+                " is not num_groups = ", num_groups);
+    TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
+                "b_zeros dim 1 = ", b_scales.size(1),
+                " is not size_n / pack_factor = ", size_n / pack_factor);
   }
 
   // Verify workspace size
@@ -2413,7 +2279,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2422,7 +2288,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
@@ -2430,4 +2296,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   return c;
 }
 
-#endif
+#endif

+ 1 - 1
kernels/quantization/gptq_marlin/gptq_marlin_repack.cu

@@ -353,4 +353,4 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
   return torch::empty_symint(
       {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
       options);
-}
+}

+ 1 - 1
kernels/quantization/gptq_marlin/marlin.cuh

@@ -84,4 +84,4 @@ __device__ inline void cp_async_wait() {
 
 #endif
 
-}  // namespace marlin
+}  // namespace marlin

+ 1 - 3
kernels/quantization/gptq_marlin/marlin_dtypes.cuh

@@ -24,7 +24,6 @@ class ScalarType<half> {
   using FragC = Vec<float, 4>;
   using FragS = Vec<half2, 1>;
   using FragZP = Vec<half2, 4>;
-  using FragZPF = Vec<half2, 1>;
 
   static __device__ float inline num2float(const half x) {
     return __half2float(x);
@@ -54,7 +53,6 @@ class ScalarType<nv_bfloat16> {
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;
   using FragZP = Vec<nv_bfloat162, 4>;
-  using FragZPF = Vec<nv_bfloat162, 1>;
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {
@@ -78,4 +76,4 @@ class ScalarType<nv_bfloat16> {
 
 }  // namespace marlin
 
-#endif
+#endif

+ 1 - 1
kernels/quantization/quant_ops.h

@@ -78,7 +78,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                int64_t size_m, int64_t size_n, int64_t size_k,
                                bool is_k_full, bool has_zp,
-                               bool use_fp32_reduce, bool is_zp_float);
+                               bool use_fp32_reduce);
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,

+ 24 - 13
kernels/torch_bindings.cpp

@@ -198,7 +198,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "__torch__.torch.classes._core_C.ScalarType b_q_type, "
       "int size_m, int size_n, int size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+      "bool has_zp, bool use_fp32_reduce) -> Tensor");
   ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
 
   // gptq_marlin repack from GPTQ.
@@ -449,22 +449,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "bool silu_activation) -> Tensor");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
 
-  ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
-          "float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
-          "float softcap, bool return_softmax, Generator? gen) -> Tensor[]");
+  ops.def(
+      "fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
+      "float p_dropout, float softmax_scale, bool is_causal, int "
+      "window_size_left, int window_size_right, "
+      "float softcap, bool return_softmax, Generator? gen) -> Tensor[]");
   ops.impl("fwd", torch::kCUDA, &mha_fwd);
 
-  ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
-          "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? block_table, Tensor? alibi_slopes, "
-          "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
-          "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
-          "Generator? gen) -> Tensor[]");
+  ops.def(
+      "varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor "
+      "cu_seqlens_q, "
+      "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? block_table, Tensor? "
+      "alibi_slopes, "
+      "int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
+      "softmax_scale, bool zero_tensors, "
+      "bool is_causal, int window_size_left, int window_size_right, float "
+      "softcap, bool return_softmax, "
+      "Generator? gen) -> Tensor[]");
   ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
 
-  ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
-          "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? block_table, Tensor? alibi_slopes, "
-          "Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
-          "float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
+  ops.def(
+      "fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? "
+      "v, Tensor? seqlens_k, "
+      "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, "
+      "Tensor? block_table, Tensor? alibi_slopes, "
+      "Tensor!? out, float softmax_scale, bool is_causal, int "
+      "window_size_left, int window_size_right, "
+      "float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
   ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
 #endif
 }

+ 223 - 4
tests/kernels/test_moe.py

@@ -2,14 +2,22 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+from typing import List
+
 import pytest
 import torch
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
+from aphrodite.common.utils import seed_everything
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.modeling.layers.fused_moe.fused_marlin_moe import (
+    fused_marlin_moe, single_marlin_moe)
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_topk
 from aphrodite.modeling.models.mixtral import MixtralMoE
+from aphrodite.quantization.utils.marlin_utils_test import marlin_quantize
+from aphrodite.scalar_type import scalar_types
 
 
 def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
+def torch_moe_single(a, w, score, topk):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    _, topk_ids = torch.topk(score, topk)
+    topk_ids = topk_ids.view(-1)
+    for i in range(w.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = a[mask] @ w[i].transpose(0, 1)
+    return (out.view(B, -1, w.shape[1])).sum(dim=1)
+
+
 @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
     topk: int,
     dtype: torch.dtype,
 ):
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
-    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,200 @@ def test_mixtral_moe(dtype: torch.dtype):
                                aphrodite_states,
                                rtol=mixtral_moe_tol[dtype],
                                atol=mixtral_moe_tol[dtype])
+
+
+def stack_and_dev(tensors: List[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref))
+
+
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+    num_bits: int,
+):
+    seed_everything(7)
+
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+            w1[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref1_l.append(w_ref1)
+        qweight1_l.append(qweight1)
+        scales1_l.append(scales1)
+        g_idx1_l.append(g_idx1)
+        sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l)
+    sort_indices1 = stack_and_dev(sort_indices1_l)
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        test_perm = torch.randperm(n)
+        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+            w2[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref2_l.append(w_ref2)
+        qweight2_l.append(qweight2)
+        scales2_l.append(scales2)
+        g_idx2_l.append(g_idx2)
+        sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l)
+    sort_indices2 = stack_and_dev(sort_indices2_l)
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    topk_weights, topk_ids = fused_topk(a, score, topk, False)
+
+    triton_output = fused_moe(
+        a,
+        w_ref1.transpose(1, 2).contiguous(),
+        w_ref2.transpose(1, 2).contiguous(),
+        score,
+        topk,
+        renormalize=False,
+    )
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        score,
+        g_idx1,
+        g_idx2,
+        sort_indices1,
+        sort_indices2,
+        topk_weights,
+        topk_ids,
+        w1_scale=scales1,
+        w2_scale=scales2,
+        num_bits=num_bits,
+    )
+
+    assert compute_max_diff(marlin_output, triton_output) < 4e-2
+
+
+@pytest.mark.skip("This test is here for the sake of debugging, "
+                  "don't run it in automated tests.")
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_single_marlin_moe_multiply(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+    num_bits: int,
+):
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size == k:
+            return
+
+    quant_type = (scalar_types.uint4b8
+                  if num_bits == 4 else scalar_types.uint8b128)
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
+
+    w_ref_l = []
+    qweights_l = []
+    scales_l = []
+    g_idx_l = []
+    sort_indices_l = []
+
+    for i in range(w.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
+            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
+        w_ref_l.append(w_ref)
+        qweights_l.append(qweight)
+        scales_l.append(scales)
+        g_idx_l.append(g_idx)
+        sort_indices_l.append(sort_indices)
+
+    w_ref = stack_and_dev(w_ref_l)
+    qweight = stack_and_dev(qweights_l).contiguous()
+    scales = stack_and_dev(scales_l)
+    g_idx = stack_and_dev(g_idx_l)
+    sort_indices = stack_and_dev(sort_indices_l)
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    marlin_output = single_marlin_moe(a,
+                                      qweight,
+                                      scales,
+                                      score,
+                                      g_idx,
+                                      sort_indices,
+                                      topk,
+                                      renormalize=False,
+                                      num_bits=num_bits)
+    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
+
+    assert compute_max_diff(marlin_output, torch_output) < 1e-2

Some files were not shown because too many files changed in this diff