Browse Source

feat: switch from `PYBIND11_MODULE` to `TORCH_LIBRARY` (#569)

* feat: switch from `PYBIND11_MODULE` to `TORCH_LIBRARY`

* fix: fp8 imports
AlpinDale 7 months ago
parent
commit
156f577f79
79 changed files with 3340 additions and 2916 deletions
  1. 81 121
      CMakeLists.txt
  2. 477 0
      aphrodite/_custom_ops.py
  3. 5 5
      aphrodite/attention/backends/flash_attn.py
  4. 1 1
      aphrodite/attention/backends/flashinfer.py
  5. 5 6
      aphrodite/attention/ops/paged_attn.py
  6. 4 4
      aphrodite/common/utils.py
  7. 4 4
      aphrodite/modeling/layers/activation.py
  8. 2 3
      aphrodite/modeling/layers/fused_moe/fused_moe.py
  9. 1 1
      aphrodite/modeling/layers/layernorm.py
  10. 1 1
      aphrodite/modeling/layers/rotary_embedding.py
  11. 0 10
      aphrodite/quantization/__init__.py
  12. 1 8
      aphrodite/quantization/aqlm.py
  13. 1 10
      aphrodite/quantization/autoquant.py
  14. 1 8
      aphrodite/quantization/awq.py
  15. 5 2
      aphrodite/quantization/compressed_tensors/compressed_tensors.py
  16. 4 5
      aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
  17. 4 5
      aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
  18. 0 48
      aphrodite/quantization/compressed_tensors/schemes/utils.py
  19. 1 8
      aphrodite/quantization/exl2.py
  20. 24 66
      aphrodite/quantization/fp8.py
  21. 1 8
      aphrodite/quantization/gguf.py
  22. 1 8
      aphrodite/quantization/gptq.py
  23. 1 8
      aphrodite/quantization/gptq_marlin.py
  24. 3 10
      aphrodite/quantization/gptq_marlin_24.py
  25. 1 8
      aphrodite/quantization/marlin.py
  26. 3 12
      aphrodite/quantization/quip.py
  27. 1 8
      aphrodite/quantization/squeezellm.py
  28. 6 5
      cmake/cpu_extension.cmake
  29. 8 3
      cmake/utils.cmake
  30. 66 76
      kernels/activation_kernels.cu
  31. 41 35
      kernels/all_reduce/custom_all_reduce.cu
  32. 17 15
      kernels/attention/attention_kernels.cu
  33. 9 5
      kernels/cache.h
  34. 8 5
      kernels/cache_kernels.cu
  35. 14 12
      kernels/cpu/attention.cpp
  36. 8 5
      kernels/cpu/cache.cpp
  37. 1 1
      kernels/cpu/cpu_types.hpp
  38. 16 16
      kernels/cpu/layernorm.cpp
  39. 1 1
      kernels/cpu/pos_encoding.cpp
  40. 0 44
      kernels/cpu/pybind.cpp
  41. 106 0
      kernels/cpu/torch_bindings.cpp
  42. 2 7
      kernels/cuda_utils.h
  43. 17 22
      kernels/cuda_utils_kernels.cu
  44. 1 1
      kernels/dispatch_utils.h
  45. 121 120
      kernels/layernorm_kernels.cu
  46. 111 102
      kernels/moe/align_block_size_kernel.cu
  47. 0 7
      kernels/moe/moe_ops.cpp
  48. 4 6
      kernels/moe/moe_ops.h
  49. 1 1
      kernels/moe/softmax.cu
  50. 12 0
      kernels/moe/torch_bindings.cpp
  51. 27 25
      kernels/ops.h
  52. 108 127
      kernels/pos_encoding_kernels.cu
  53. 3 3
      kernels/punica/punica_ops.cu
  54. 4 4
      kernels/punica/punica_ops.h
  55. 0 11
      kernels/punica/punica_pybind.cpp
  56. 18 0
      kernels/punica/torch_bindings.cpp
  57. 0 82
      kernels/pybind.cpp
  58. 1 1
      kernels/quantization/aqlm/gemm_kernels.cu
  59. 4 4
      kernels/quantization/autoquant/int4_fp16_gemm_kernels.cu
  60. 950 764
      kernels/quantization/awq/gemm_kernels.cu
  61. 1 1
      kernels/quantization/compressed_tensors/int8_quant_kernels.cu
  62. 1 1
      kernels/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
  63. 1 1
      kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
  64. 1 1
      kernels/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
  65. 101 180
      kernels/quantization/exl2/q_gemm_exl2.cu
  66. 311 305
      kernels/quantization/exl2/q_matrix.cu
  67. 1 1
      kernels/quantization/fp8/common.cu
  68. 3 3
      kernels/quantization/gptq/q_gemm.cu
  69. 1 1
      kernels/quantization/gptq_marlin/gptq_marlin.cuh
  70. 1 1
      kernels/quantization/marlin/dense/marlin_cuda_kernel.cu
  71. 1 1
      kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
  72. 0 72
      kernels/quantization/quant_ops.cpp
  73. 11 50
      kernels/quantization/quant_ops.h
  74. 231 340
      kernels/quantization/quip/origin_order.cu
  75. 27 37
      kernels/quantization/squeezellm/quant_cuda_kernel.cu
  76. 22 0
      kernels/registration.h
  77. 292 0
      kernels/torch_bindings.cpp
  78. 15 32
      setup.py
  79. 1 1
      tests/benchmarks/attention.py

+ 81 - 121
CMakeLists.txt

@@ -66,19 +66,6 @@ endif()
 #
 find_package(Torch REQUIRED)
 
-#
-# Normally `torch.utils.cpp_extension.CUDAExtension` would add
-# `libtorch_python.so` for linking against an extension. Torch's cmake
-# configuration does not include this library (presumably since the cmake
-# config is used for standalone C++ binaries that link against torch).
-# The `libtorch_python.so` library defines some of the glue code between
-# torch/python via pybind and is required by APHRODITE extensions for this
-# reason. So, add it by manually using `append_torchlib_if_found` from
-# torch's cmake setup.
-#
-find_library(torch_python_LIBRARY torch_python PATHS
-  "${TORCH_INSTALL_PREFIX}/lib")
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -165,71 +152,45 @@ set(APHRODITE_EXT_SRC
   "kernels/pos_encoding_kernels.cu"
   "kernels/activation_kernels.cu"
   "kernels/layernorm_kernels.cu"
-  "kernels/cuda_utils_kernels.cu"
-  "kernels/moe/align_block_size_kernel.cu"
-  "kernels/pybind.cpp")
-
-if(APHRODITE_GPU_LANG STREQUAL "CUDA")
-  list(APPEND APHRODITE_EXT_SRC
-    "kernels/all_reduce/custom_all_reduce.cu")
-endif()
-
-define_gpu_extension_target(
-  _C
-  DESTINATION aphrodite
-  LANGUAGE ${APHRODITE_GPU_LANG}
-  SOURCES ${APHRODITE_EXT_SRC}
-  COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
-  ARCHITECTURES ${APHRODITE_GPU_ARCHES}
-  WITH_SOABI)
-
-
-#
-# _quant_C extension
-#
-
-set (APHRODITE_QUANT_EXT_SRC
-  "kernels/quantization/gptq/q_gemm.cu"
   "kernels/quantization/squeezellm/quant_cuda_kernel.cu"
-  "kernels/quantization/exl2/q_matrix.cu"
-  "kernels/quantization/exl2/q_gemm_exl2.cu"
-  "kernels/quantization/fp8/common.cu"
+  "kernels/quantization/gptq/q_gemm.cu"
   "kernels/quantization/compressed_tensors/int8_quant_kernels.cu"
-  "kernels/quantization/quant_ops.cpp")
+  "kernels/quantization/fp8/common.cu"
+  "kernels/cuda_utils_kernels.cu"
+  "kernels/moe/align_block_size_kernel.cu"
+  "kernels/torch_bindings.cpp")
 
 if(APHRODITE_GPU_LANG STREQUAL "CUDA")
-include(FetchContent)
+  include(FetchContent)
   SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
   FetchContent_Declare(
-        cutlass 
+        cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
         # CUTLASS 3.5.0
         GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
   )
   FetchContent_MakeAvailable(cutlass)
-  list(APPEND APHRODITE_QUANT_EXT_SRC
+
+  list(APPEND APHRODITE_EXT_SRC
     "kernels/quantization/aqlm/gemm_kernels.cu"
     "kernels/quantization/awq/gemm_kernels.cu"
-    "kernels/quantization/autoquant/int4_fp16_gemm_kernels.cu"
-    "kernels/quantization/autoquant/format.cu"
-    "kernels/quantization/autoquant/gemm_s4_f16.cu"
-    "kernels/quantization/gguf/gguf_kernel.cu"
+    "kernels/quantization/quip/origin_order.cu"
     "kernels/quantization/marlin/dense/marlin_cuda_kernel.cu"
     "kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "kernels/quantization/gptq_marlin/gptq_marlin.cu"
     "kernels/quantization/gptq_marlin/gptq_marlin_repack.cu"
-    "kernels/quantization/quip/origin_order.cu"
+    "kernels/all_reduce/custom_all_reduce.cu"
     "kernels/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
     "kernels/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
     "kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
 
-    #
-    # The CUTLASS kernels for Hopper require sm90a to be enabled.
-    # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
-    # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
-    if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
+  #
+  # The CUTLASS kernels for Hopper require sm90a to be enabled.
+  # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
+  # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
     set_source_files_properties(
-          "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
+          "kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
           PROPERTIES
           COMPILE_FLAGS
           "-gencode arch=compute_90a,code=sm_90a")
@@ -238,13 +199,14 @@ include(FetchContent)
 endif()
 
 define_gpu_extension_target(
-  _quant_C
+  _C
   DESTINATION aphrodite
   LANGUAGE ${APHRODITE_GPU_LANG}
-  SOURCES ${APHRODITE_QUANT_EXT_SRC}
+  SOURCES ${APHRODITE_EXT_SRC}
   COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
   ARCHITECTURES ${APHRODITE_GPU_ARCHES}
   INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  USE_SABI 3
   WITH_SOABI)
 
 #
@@ -252,7 +214,7 @@ define_gpu_extension_target(
 #
 
 set(APHRODITE_MOE_EXT_SRC
-  "kernels/moe/moe_ops.cpp"
+  "kernels/moe/torch_bindings.cpp"
   "kernels/moe/softmax.cu")
 
 define_gpu_extension_target(
@@ -262,6 +224,7 @@ define_gpu_extension_target(
   SOURCES ${APHRODITE_MOE_EXT_SRC}
   COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
   ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+  USE_SABI 3
   WITH_SOABI)
 
 #
@@ -288,7 +251,7 @@ set(APHRODITE_PUNICA_EXT_SRC
   "kernels/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
   "kernels/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
   "kernels/punica/punica_ops.cu"
-  "kernels/punica/punica_pybind.cpp")
+  "kernels/punica/torch_bindings.cpp")
 
 #
 # Copy GPU compilation flags+update for punica
@@ -325,63 +288,64 @@ if (APHRODITE_PUNICA_GPU_ARCHES)
     SOURCES ${APHRODITE_PUNICA_EXT_SRC}
     COMPILE_FLAGS ${APHRODITE_PUNICA_GPU_FLAGS}
     ARCHITECTURES ${APHRODITE_PUNICA_GPU_ARCHES}
+    USE_SABI 3
     WITH_SOABI)
 else()
   message(WARNING "Unable to create _punica_C target because none of the "
     "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 8.0")
 endif()
 
-#
-# _hadamard_C extension
-#
-
-set(APHRODITE_HADAMARD_EXT_SRC
-  "kernels/hadamard/fast_hadamard_transform.cpp"
-  "kernels/hadamard/fast_hadamard_transform_cuda.cu")
-
-#
-# Copy GPU compilation flags+update for hadamard
-#
-set(APHRODITE_HADAMARD_GPU_FLAGS ${APHRODITE_GPU_FLAGS})
-list(APPEND APHRODITE_HADAMARD_GPU_FLAGS
-  "-U__CUDA_NO_HALF_OPERATORS__"
-  "-U__CUDA_NO_HALF_CONVERSIONS__"
-  "-U__CUDA_NO_BFLOAT16_OPERATORS__"
-  "-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
-  "-U__CUDA_NO_BFLOAT162_OPERATORS__"
-  "-U__CUDA_NO_BFLOAT162_CONVERSIONS__"
-  "--expt-relaxed-constexpr"
-  "--expt-extended-lambda"
-  "--use_fast_math"
-  "-lineinfo")
-
-#
-# Filter out CUDA architectures < 7.0 for hadamard.
-#
-if (${APHRODITE_GPU_LANG} STREQUAL "CUDA")
-  set(APHRODITE_HADAMARD_GPU_ARCHES)
-  foreach(ARCH ${APHRODITE_GPU_ARCHES})
-    string_to_ver(CODE_VER ${ARCH})
-    if (CODE_VER GREATER_EQUAL 6.0)
-      list(APPEND APHRODITE_HADAMARD_GPU_ARCHES ${ARCH})
-    endif()
-  endforeach()
-  message(STATUS "Hadamard target arches: ${APHRODITE_HADAMARD_GPU_ARCHES}")
-endif()
-
-if (APHRODITE_HADAMARD_GPU_ARCHES)
-  define_gpu_extension_target(
-    _hadamard_C
-    DESTINATION aphrodite
-    LANGUAGE ${APHRODITE_GPU_LANG}
-    SOURCES ${APHRODITE_HADAMARD_EXT_SRC}
-    COMPILE_FLAGS ${APHRODITE_HADAMARD_GPU_FLAGS}
-    ARCHITECTURES ${APHRODITE_HADAMARD_GPU_ARCHES}
-    WITH_SOABI)
-else()
-  message(WARNING "Unable to create _hadamard_C target because none of the "
-    "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 6.0")
-endif()
+# #
+# # _hadamard_C extension
+# #
+
+# set(APHRODITE_HADAMARD_EXT_SRC
+#   "kernels/hadamard/fast_hadamard_transform.cpp"
+#   "kernels/hadamard/fast_hadamard_transform_cuda.cu")
+
+# #
+# # Copy GPU compilation flags+update for hadamard
+# #
+# set(APHRODITE_HADAMARD_GPU_FLAGS ${APHRODITE_GPU_FLAGS})
+# list(APPEND APHRODITE_HADAMARD_GPU_FLAGS
+#   "-U__CUDA_NO_HALF_OPERATORS__"
+#   "-U__CUDA_NO_HALF_CONVERSIONS__"
+#   "-U__CUDA_NO_BFLOAT16_OPERATORS__"
+#   "-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
+#   "-U__CUDA_NO_BFLOAT162_OPERATORS__"
+#   "-U__CUDA_NO_BFLOAT162_CONVERSIONS__"
+#   "--expt-relaxed-constexpr"
+#   "--expt-extended-lambda"
+#   "--use_fast_math"
+#   "-lineinfo")
+
+# #
+# # Filter out CUDA architectures < 7.0 for hadamard.
+# #
+# if (${APHRODITE_GPU_LANG} STREQUAL "CUDA")
+#   set(APHRODITE_HADAMARD_GPU_ARCHES)
+#   foreach(ARCH ${APHRODITE_GPU_ARCHES})
+#     string_to_ver(CODE_VER ${ARCH})
+#     if (CODE_VER GREATER_EQUAL 6.0)
+#       list(APPEND APHRODITE_HADAMARD_GPU_ARCHES ${ARCH})
+#     endif()
+#   endforeach()
+#   message(STATUS "Hadamard target arches: ${APHRODITE_HADAMARD_GPU_ARCHES}")
+# endif()
+
+# if (APHRODITE_HADAMARD_GPU_ARCHES)
+#   define_gpu_extension_target(
+#     _hadamard_C
+#     DESTINATION aphrodite
+#     LANGUAGE ${APHRODITE_GPU_LANG}
+#     SOURCES ${APHRODITE_HADAMARD_EXT_SRC}
+#     COMPILE_FLAGS ${APHRODITE_HADAMARD_GPU_FLAGS}
+#     ARCHITECTURES ${APHRODITE_HADAMARD_GPU_ARCHES}
+#     WITH_SOABI)
+# else()
+#   message(WARNING "Unable to create _hadamard_C target because none of the "
+#     "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 6.0")
+# endif()
 
 #
 # Add the `default` target which detects which extensions should be
@@ -404,22 +368,18 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
 
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
+
   # Enable punica if -DAPHRODITE_INSTALL_PUNICA_KERNELS=ON or
   # APHRODITE_INSTALL_PUNICA_KERNELS is set in the environment and
   # there are supported target arches.
-  if (APHRODITE_QUANT_EXT_SRC AND
-      (ENV{APHRODITE_INSTALL_QUANT_KERNELS} OR APHRODITE_INSTALL_QUANT_KERNELS))
-    message(STATUS "Enabling quant extension.")
-    add_dependencies(default _quant_C)
-  endif()
   if (APHRODITE_PUNICA_GPU_ARCHES AND
       (ENV{APHRODITE_INSTALL_PUNICA_KERNELS} OR APHRODITE_INSTALL_PUNICA_KERNELS))
     message(STATUS "Enabling punica extension.")
     add_dependencies(default _punica_C)
   endif()
-  if (APHRODITE_HADAMARD_GPU_ARCHES AND
-      (ENV{APHRODITE_INSTALL_HADAMARD_KERNELS} OR APHRODITE_INSTALL_HADAMARD_KERNELS))
-    message(STATUS "Enabling hadamard extension.")
-    add_dependencies(default _hadamard_C)
-  endif()
+  # if (APHRODITE_HADAMARD_GPU_ARCHES AND
+  #     (ENV{APHRODITE_INSTALL_HADAMARD_KERNELS} OR APHRODITE_INSTALL_HADAMARD_KERNELS))
+  #   message(STATUS "Enabling hadamard extension.")
+  #   add_dependencies(default _hadamard_C)
+  # endif()
 endif()

+ 477 - 0
aphrodite/_custom_ops.py

@@ -0,0 +1,477 @@
+import contextlib
+from typing import List, Optional, Tuple, Type
+
+import torch
+
+try:
+    import aphrodite._C
+except ImportError as e:
+    from loguru import logger
+    logger.warning("Failed to import from vllm._C with %r", e)
+
+with contextlib.suppress(ImportError):
+    import aphrodite._moe_C
+
+with contextlib.suppress(ImportError):
+    # ruff: noqa: F401
+    import aphrodite._punica_C
+
+
+def is_custom_op_supported(op_name: str) -> bool:
+    op, overloads = torch._C._jit_get_operation(op_name)
+    return op is not None
+
+
+# activation ops
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.silu_and_mul(out, x)
+
+
+def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_and_mul(out, x)
+
+
+def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_tanh_and_mul(out, x)
+
+
+def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_fast(out, x)
+
+
+def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_new(out, x)
+
+
+# page attention ops
+def paged_attention_v1(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    kv_scale: float,
+    tp_rank: int = 0,
+    blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
+) -> None:
+    torch.ops._C.paged_attention_v1(
+        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
+        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
+        kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
+        blocksparse_block_size, blocksparse_head_sliding_step)
+
+
+def paged_attention_v2(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    kv_scale: float,
+    tp_rank: int = 0,
+    blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
+) -> None:
+    torch.ops._C.paged_attention_v2(
+        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
+        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
+        alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
+        blocksparse_local_blocks, blocksparse_vert_stride,
+        blocksparse_block_size, blocksparse_head_sliding_step)
+
+
+# pos encoding ops
+def rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    torch.ops._C.rotary_embedding(positions, query, key, head_size,
+                                  cos_sin_cache, is_neox)
+
+
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+                             key: torch.Tensor, head_size: int,
+                             cos_sin_cache: torch.Tensor, is_neox: bool,
+                             rot_dim: int,
+                             cos_sin_cache_offsets: torch.Tensor) -> None:
+    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
+                                          cos_sin_cache, is_neox, rot_dim,
+                                          cos_sin_cache_offsets)
+
+
+# layer norm ops
+def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+             epsilon: float) -> None:
+    torch.ops._C.rms_norm(out, input, weight, epsilon)
+
+
+def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+                       weight: torch.Tensor, epsilon: float) -> None:
+    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
+
+
+# quantization ops
+# awq
+def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
+                   zeros: torch.Tensor, split_k_iters: int, thx: int,
+                   thy: int) -> torch.Tensor:
+    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
+                                       thx, thy)
+
+
+def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
+             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
+    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+
+
+# gptq
+def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
+              b_g_idx: torch.Tensor, use_exllama: bool,
+              bit: int) -> torch.Tensor:
+    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+                                  b_g_idx, use_exllama, bit)
+
+
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
+                 bit: int) -> None:
+    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
+
+
+# squeezellm
+def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
+                    lookup_table: torch.Tensor) -> None:
+    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
+
+
+# marlin
+def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
+                size_n: int, size_k: int) -> torch.Tensor:
+    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
+                                    size_n, size_k)
+
+
+# marlin_24
+def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                        b_meta: torch.Tensor, b_scales: torch.Tensor,
+                        workspace: torch.Tensor, num_bits: int, size_m: int,
+                        size_n: int, size_k: int) -> torch.Tensor:
+    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
+                                            workspace, num_bits, size_m,
+                                            size_n, size_k)
+
+
+# cutlass
+def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
+                         scale_a: torch.Tensor, scale_b: torch.Tensor,
+                         out_dtype: Type[torch.dtype]) -> torch.Tensor:
+    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+
+    m = a.shape[0]
+    n = b.shape[1]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
+
+    return out
+
+
+# aqlm
+def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
+              codebooks: torch.Tensor, scales: torch.Tensor,
+              codebook_partition_sizes: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
+                                  codebook_partition_sizes, bias)
+
+
+def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
+                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+    return torch.ops._C.aqlm_dequant(codes, codebooks,
+                                     codebook_partition_sizes)
+
+
+# gptq_marlin
+def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
+                       size_k: int, size_n: int,
+                       num_bits: int) -> torch.Tensor:
+    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
+                                           num_bits)
+
+
+def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                     b_scales: torch.Tensor, g_idx: torch.Tensor,
+                     perm: torch.Tensor, workspace: torch.Tensor,
+                     num_bits: int, size_m: int, size_n: int, size_k: int,
+                     is_k_full: bool) -> torch.Tensor:
+    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
+                                         workspace, num_bits, size_m, size_n,
+                                         size_k, is_k_full)
+
+
+# fp8
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    batch_dim_padding: Optional[int] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensor for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        batch_dim_padding: If specified, pad the first dimension
+            of the output to at least this value.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    if batch_dim_padding:
+        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
+        output = torch.empty(shape,
+                             device=input.device,
+                             dtype=torch.float8_e4m3fn)
+    else:
+        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
+    return output, scale
+
+
+# int8
+def scaled_int8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize the input tensor to int8 and return the quantized tensor and scale.
+
+    Args:
+        input: The input tensor to be quantized to int8.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
+
+    Returns:
+      Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
+    """
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        torch.ops._C.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales
+
+
+# quip#
+def quip_gemv(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    CB: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops._C.quip_gemv(A, B, CB)
+
+
+def quip_decompress(
+    YIs: torch.Tensor,
+    CB: torch.Tensor,
+    Y: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops._C.quip_decompress(YIs, CB, Y)
+
+
+# moe
+def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
+                         block_size: int, sorted_token_ids: torch.Tensor,
+                         experts_ids: torch.Tensor,
+                         num_tokens_post_pad: torch.Tensor) -> None:
+    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
+                                      sorted_token_ids, experts_ids,
+                                      num_tokens_post_pad)
+
+
+def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                 token_expert_indicies: torch.Tensor,
+                 gating_output: float) -> None:
+    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
+                                  token_expert_indicies, gating_output)
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
+                                             value_cache, slot_mapping,
+                                             kv_cache_dtype, kv_scale)
+
+
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+) -> None:
+    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
+                                                   value_cache, slot_mapping,
+                                                   kv_cache_dtype)
+
+
+def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
+                block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
+                block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(output: torch.Tensor,
+                input: torch.Tensor,
+                scale: float = 1.0,
+                kv_dtype: str = "fp8") -> None:
+    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+def get_device_attribute(attribute: int, device: int) -> int:
+    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
+
+
+def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
+    # ruff: noqa: E501
+    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
+        device)
+
+
+# custom ar
+def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
+                   handles: List[str], offsets: List[int], rank: int,
+                   full_nvlink: bool) -> int:
+    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
+                                                 offsets, rank, full_nvlink)
+
+
+def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
+                     full_nvlink: bool) -> bool:
+    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
+                                                   full_nvlink)
+
+
+def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
+
+
+def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
+                     out: torch.Tensor) -> None:
+    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+
+def dispose(fa: int) -> None:
+    torch.ops._C_custom_ar.dispose(fa)
+
+
+def meta_size() -> int:
+    return torch.ops._C_custom_ar.meta_size()
+
+
+def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
+                    offsets: List[int]) -> None:
+    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
+
+
+def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
+    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+
+def register_graph_buffers(fa: int, handles: List[str],
+                           offsets: List[List[int]]) -> None:
+    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+
+# punica
+def dispatch_bgmv(
+    y: torch.Tensor,
+    x: torch.Tensor,
+    w_t_all: torch.Tensor,
+    indicies: torch.Tensor,
+    layer_idx: int,
+    scale: float,
+) -> None:
+    torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
+                                      scale)
+
+
+def dispatch_bgmv_low_level(
+    y: torch.Tensor,
+    x: torch.Tensor,
+    w_t_all: torch.Tensor,
+    indicies: torch.Tensor,
+    layer_idx: int,
+    scale: float,
+    h_in: int,
+    h_out: int,
+    y_offset: int,
+) -> None:
+    torch.ops._punica_C.dispatch_bgmv_low_level(
+        y,
+        x,
+        w_t_all,
+        indicies,
+        layer_idx,
+        scale,
+        h_in,
+        h_out,
+        y_offset,
+    )

+ 5 - 5
aphrodite/attention/backends/flash_attn.py

@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
-from aphrodite._C import cache_ops
+from aphrodite import _custom_ops as ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata)
@@ -48,11 +48,11 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
@@ -61,7 +61,7 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass
@@ -286,7 +286,7 @@ class FlashAttentionImpl(AttentionImpl):
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            cache_ops.reshape_and_cache_flash(
+            ops.reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,

+ 1 - 1
aphrodite/attention/backends/flashinfer.py

@@ -6,7 +6,7 @@ import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from vllm_flash_attn import flash_attn_varlen_func
 
-from aphrodite._C import cache_ops as ops
+from aphrodite import _custom_ops as ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata)

+ 5 - 6
aphrodite/attention/ops/paged_attn.py

@@ -3,8 +3,7 @@ from typing import List, Optional, Tuple
 
 import torch
 
-from aphrodite._C import ops
-from aphrodite._C import cache_ops
+from aphrodite import _custom_ops as ops
 from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -69,7 +68,7 @@ class PagedAttention:
         kv_cache_dtype: str,
         kv_scale: float,
     ) -> None:
-        cache_ops.reshape_and_cache(
+        ops.reshape_and_cache(
             key,
             value,
             key_cache,
@@ -224,11 +223,11 @@ class PagedAttention:
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
@@ -237,4 +236,4 @@ class PagedAttention:
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)

+ 4 - 4
aphrodite/common/utils.py

@@ -149,10 +149,10 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
     # NOTE: This import statement should be executed lazily since
     # the Neuron-X backend does not have the `cuda_utils` module.
-    from aphrodite._C import cuda_utils
+    from aphrodite import _custom_ops as ops
 
     max_shared_mem = (
-        cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
+        ops.get_max_shared_memory_per_block_device_attribute(gpu))
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
     # will fail
     assert max_shared_mem > 0, "max_shared_mem can not be zero"
@@ -329,10 +329,10 @@ def _generate_random_fp8(
     #-----|-------------|-------------------
     # Inf | N/A         | s.11111.00
     # NaN | s.1111.111  | s.11111.{01,10,11}
-    from aphrodite._C import cache_ops
+    from aphrodite import _custom_ops as ops
     tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
     tensor_tmp.uniform_(low, high)
-    cache_ops.convert_fp8(tensor, tensor_tmp)
+    ops.convert_fp8(tensor, tensor_tmp)
     del tensor_tmp
 
 

+ 4 - 4
aphrodite/modeling/layers/activation.py

@@ -29,7 +29,7 @@ class SiluAndMul(CustomOp):
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from aphrodite._C import ops
+        from aphrodite import _custom_ops as ops
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -59,7 +59,7 @@ class GeluAndMul(CustomOp):
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from aphrodite._C import ops
+        from aphrodite import _custom_ops as ops
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -82,7 +82,7 @@ class NewGELU(CustomOp):
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from aphrodite._C import ops
+        from aphrodite import _custom_ops as ops
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
         return out
@@ -96,7 +96,7 @@ class FastGELU(CustomOp):
                                            (1.0 + 0.044715 * x * x)))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from aphrodite._C import ops
+        from aphrodite import _custom_ops as ops
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
         return out

+ 2 - 3
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -9,8 +9,7 @@ import triton
 import triton.language as tl
 from loguru import logger
 
-import aphrodite._moe_C as moe_kernels
-from aphrodite._C import ops
+from aphrodite import _custom_ops as ops
 
 
 @triton.jit
@@ -353,7 +352,7 @@ def fused_topk(
                                         topk,
                                         dtype=torch.int32,
                                         device=hidden_states.device)
-    moe_kernels.topk_softmax(
+    ops.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,

+ 1 - 1
aphrodite/modeling/layers/layernorm.py

@@ -73,7 +73,7 @@ class RMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        from aphrodite._C import ops
+        from aphrodite import _custom_ops as ops
         if residual is not None:
             ops.fused_add_rms_norm(
                 x,

+ 1 - 1
aphrodite/modeling/layers/rotary_embedding.py

@@ -147,7 +147,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from aphrodite._C import ops
+        from aphrodite import _custom_ops as ops
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()

+ 0 - 10
aphrodite/quantization/__init__.py

@@ -1,7 +1,5 @@
 from typing import Type
 
-from loguru import logger
-
 from aphrodite.quantization.aqlm import AQLMConfig
 from aphrodite.quantization.autoquant import AutoQuantConfig
 from aphrodite.quantization.awq import AWQConfig
@@ -21,14 +19,6 @@ from aphrodite.quantization.marlin import MarlinConfig
 from aphrodite.quantization.quip import QuipConfig
 from aphrodite.quantization.squeezellm import SqueezeLLMConfig
 
-try:
-    from aphrodite._quant_C import quant_ops  # noqa: F401
-except ImportError:
-    logger.warning("The Quantization Kernels are not installed. "
-                   "To use quantization with Aphrodite, make sure "
-                   "you've exported the `APHRODITE_INSTALL_QUANT_KERNELS=1`"
-                   "environment variable during the compilation process.")
-
 QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,

+ 1 - 8
aphrodite/quantization/aqlm.py

@@ -2,22 +2,17 @@
 # and https://arxiv.org/pdf/2401.06118.pdf
 
 import math
-from contextlib import suppress
 from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 def get_int_dtype(nbits: int) -> torch.dtype:
     if nbits <= 8:
@@ -229,8 +224,6 @@ class AQLMLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: AQLMConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module,

+ 1 - 10
aphrodite/quantization/autoquant.py

@@ -1,9 +1,9 @@
-from contextlib import suppress
 from typing import Any, Dict, List, NamedTuple, Optional, TypeVar
 
 import torch
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -11,11 +11,6 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 class AutoQuantConfig(QuantizationConfig):
     """Config class for AutoQuant.
@@ -30,8 +25,6 @@ class AutoQuantConfig(QuantizationConfig):
             from_float: bool,
             quant_mode: str,  # llm_int8, smoothquant, weight_only
     ) -> None:
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point
@@ -105,8 +98,6 @@ class AutoQuantLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: AutoQuantConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module,

+ 1 - 8
aphrodite/quantization/awq.py

@@ -1,18 +1,13 @@
-from contextlib import suppress
 from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 class AWQConfig(QuantizationConfig):
     """Config class for AWQ.
@@ -83,8 +78,6 @@ class AWQLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: AWQConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module,

+ 5 - 2
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -128,7 +128,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
                        **extra_weight_attrs):
         """
         Use the CompressedTensorsScheme associated with each layer to create 
-        the necessary parameters for the layer.
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
 
@@ -150,7 +152,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         """
         Use the output of create_weights and the CompressedTensorsScheme 
         associated with the layer to apply the forward pass with the 
-        layer input.
+        layer input.  See LinearMethodBase for param details
+
         """
 
         if bias is not None:

+ 4 - 5
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

@@ -3,11 +3,10 @@ from typing import Callable, List, Tuple, Union
 import torch
 from torch.nn import Parameter
 
+from aphrodite import _custom_ops as custom_ops
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.compressed_tensors.schemes import \
     CompressedTensorsScheme
-from aphrodite.quantization.compressed_tensors.schemes.utils import (
-    cutlass_scaled_mm_dq, scaled_int8_quant)
 
 __all__ = ["CompressedTensorsW8A8DynamicToken"]
 
@@ -81,6 +80,6 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
         weight = layer.weight
         weight_scale = layer.weight_scale
 
-        x_q, input_scales = scaled_int8_quant(x)
-        return cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
-                                    weight_scale, x.dtype)
+        x_q, input_scales = custom_ops.scaled_int8_quant(x)
+        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
+                                               weight_scale, x.dtype)

+ 4 - 5
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

@@ -3,11 +3,10 @@ from typing import Callable, List, Tuple, Union
 import torch
 from torch.nn import Parameter
 
+from aphrodite import _custom_ops as custom_ops
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.compressed_tensors.schemes import \
     CompressedTensorsScheme
-from aphrodite.quantization.compressed_tensors.schemes.utils import (
-    cutlass_scaled_mm_dq, scaled_int8_quant)
 
 __all__ = ["CompressedTensorsW8A8StaticTensor"]
 
@@ -98,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q, _ = scaled_int8_quant(x, act_scale)
+        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
 
-        return cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale,
-                                    x.dtype)
+        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
+                                               weight_scale, x.dtype)

+ 0 - 48
aphrodite/quantization/compressed_tensors/schemes/utils.py

@@ -1,48 +0,0 @@
-from typing import Optional, Tuple, Type
-
-import torch
-
-from aphrodite._quant_C import quant_ops
-
-
-# cutlass
-def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
-                         scale_a: torch.Tensor, scale_b: torch.Tensor,
-                         out_dtype: Type[torch.dtype]) -> torch.Tensor:
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
-    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
-    m = a.shape[0]
-    n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-
-    quant_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
-
-    return out
-
-
-# int8
-def scaled_int8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize the input tensor to int8 and return the quantized tensor and scale.
-    Args:
-        input: The input tensor to be quantized to int8.
-        scale: Optional scaling factor for the int8 quantization.
-            When not provided, we invoke dynamic-per-token quantization.
-    Returns:
-      Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
-    """
-    output = torch.empty_like(input, dtype=torch.int8)
-    if scale is not None:
-        # static-per-tensor quantization.
-        quant_ops.static_scaled_int8_quant(output, input, scale)
-        return output, scale
-
-    # dynamic-per-token quantization.
-    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
-                               device=input.device,
-                               dtype=torch.float32)
-    quant_ops.dynamic_scaled_int8_quant(output, input, input_scales)
-    return output, input_scales

+ 1 - 8
aphrodite/quantization/exl2.py

@@ -1,17 +1,12 @@
-from contextlib import suppress
 from typing import Any, Dict, List, Optional
 
 import torch
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 def make_group_map(q_groups, num_qrows):
     gr = q_groups.tolist()
@@ -88,8 +83,6 @@ class Exl2LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Exl2Config):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module,

+ 24 - 66
aphrodite/quantization/fp8.py

@@ -1,4 +1,3 @@
-from contextlib import suppress
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -6,58 +5,16 @@ from loguru import logger
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from aphrodite.common.utils import print_warning_once
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
-from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_statictensor import \
-    cutlass_scaled_mm_dq  # noqa: E501
-
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.common.utils import print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 
-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    batch_dim_padding: Optional[int] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 and return quantized tensor and scale.
-    This function supports both static and dynamic quantization: If you
-    provide the scale, it will use static scaling and if you omit it,
-    the scale will be determined dynamically. The function also allows
-    optional padding of the output tensor for downstream kernels that
-    will benefit from padding.
-    Args:
-        input: The input tensor to be quantized to FP8
-        scale: Optional scaling factor for the FP8 quantization
-        batch_dim_padding: If specified, pad the first dimension
-            of the output to at least this value.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
-            scaling factor.
-    """
-    if batch_dim_padding:
-        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
-        output = torch.empty(shape,
-                             device=input.device,
-                             dtype=torch.float8_e4m3fn)
-    else:
-        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    if scale is None:
-        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-        ops.dynamic_scaled_fp8_quant(output, input, scale)
-    else:
-        ops.static_scaled_fp8_quant(output, input, scale)
-    return output, scale
-
-
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
@@ -119,8 +76,7 @@ class Fp8Config(QuantizationConfig):
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
-        from aphrodite.attention.layer import \
-            Attention  # Avoid circular import
+        from aphrodite.attention.layer import Attention  # Avoid circular import
 
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
@@ -136,21 +92,21 @@ class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
+
     Also supports loading quantized FP16/BF16 model checkpoints with dynamic
     activation scaling. The weight scaling factor will be initialized after
     the model weights are loaded.
+
     Limitations:
     1. Only support per-tensor quantization due to torch._scaled_mm support.
     2. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
-       
+
     Args:
         quant_config: The quantization config.
     """
 
     def __init__(self, quant_config: Fp8Config):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -213,7 +169,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 output_partition_sizes=output_partition_sizes,
                 **extra_weight_attrs)
 
-            # ACTIVATION SCALE
+            # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
                 self._create_scale_param(
                     scale_name="input_scale",
@@ -244,7 +200,8 @@ class Fp8LinearMethod(LinearMethodBase):
 
         # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = scaled_fp8_quant(layer.weight, scale=None)
+            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
+                                                         scale=None)
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
@@ -273,16 +230,16 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             layer.weight = Parameter(weight.t(), requires_grad=False)
 
-            # ACT_SCALE
+            # INPUT ACTIVATION SCALE
             #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
-            #   Static:  set to max of the act_scales (since they are equal).
+            #   Static:  set to max of the input_scales (since they are equal).
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
                 if not all_close_1d(layer.input_scale):
                     raise ValueError(
-                        "All the act_scales for the logical weights of a layer "
-                        f"must be equal. But got {layer.input_scale}")
+                        "All the input_scales for the logical weights of a "
+                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
@@ -293,15 +250,16 @@ class Fp8LinearMethod(LinearMethodBase):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         #   If dynamic, layer.input_scale is None and x_scale computed from x.
-        #   If static,  layer.input_scale is scalar and x_scale set to
-        # input_scale.
+        #   If static, layer.input_scale is scalar and x_scale is input_scale.
+
         if bias is None and self.cutlass_fp8_supported:
-            qinput, x_scale = scaled_fp8_quant(x, layer.input_scale)
+            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
 
             # Fused GEMM_DQ
-            output = cutlass_scaled_mm_dq(
+            output = ops.cutlass_scaled_mm_dq(
                 qinput,
                 layer.weight,
                 out_dtype=x.dtype,
@@ -310,9 +268,9 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         else:
-            qinput, x_scale = scaled_fp8_quant(x,
-                                               layer.input_scale,
-                                               batch_dim_padding=17)
+            qinput, x_scale = ops.scaled_fp8_quant(x,
+                                                   layer.input_scale,
+                                                   batch_dim_padding=17)
 
             # Fused GEMM_DQ -- note we padded the input above because
             # torch._scaled_mm is more performant for matrices with
@@ -338,8 +296,8 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module):
-        """Create "weight" (aka kv_scale) for an attention layer. 
-        
+        """Create "weight" (aka kv_scale) for an attention layer.
+
         Args:
             layer: The layer that is using the QuantizeMethodBase factory.
         """

+ 1 - 8
aphrodite/quantization/gguf.py

@@ -1,18 +1,13 @@
-from contextlib import suppress
 from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 GGML_QUANT_SIZES = {
     0: (1, 4),  # F32
     1: (1, 2),  # F16
@@ -92,8 +87,6 @@ class GGUFLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: GGUFConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module,

+ 1 - 8
aphrodite/quantization/gptq.py

@@ -1,5 +1,4 @@
 import enum
-from contextlib import suppress
 from enum import Enum
 from fractions import Fraction
 from typing import Any, Dict, List, Optional
@@ -7,15 +6,11 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
@@ -92,8 +87,6 @@ class GPTQLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: GPTQConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(

+ 1 - 8
aphrodite/quantization/gptq_marlin.py

@@ -1,5 +1,4 @@
 import enum
-from contextlib import suppress
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
@@ -7,15 +6,11 @@ import torch
 from loguru import logger
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -189,8 +184,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: GPTQMarlinConfig) -> None:
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(

+ 3 - 10
aphrodite/quantization/gptq_marlin_24.py

@@ -1,18 +1,13 @@
 from typing import Any, Dict, List, Optional
-from contextlib import suppress
 
 import torch
-from torch.nn.parameter import Parameter
 from loguru import logger
+from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.quantization.base_config import (QuantizationConfig)
 from aphrodite.modeling.utils import set_weight_attrs
-
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
+from aphrodite.quantization.base_config import QuantizationConfig
 
 GPTQ_MARLIN_24_TILE = 16
 GPTQ_MARLIN_24_MIN_THREAD_N = 128
@@ -130,8 +125,6 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: GPTQMarlin24Config):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(

+ 1 - 8
aphrodite/quantization/marlin.py

@@ -1,19 +1,14 @@
-from contextlib import suppress
 from typing import Any, Dict, List, Optional
 
 import torch
 from loguru import logger
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 class MarlinConfig(QuantizationConfig):
     """Config class for Marlin.
@@ -114,8 +109,6 @@ class MarlinLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: MarlinConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(

+ 3 - 12
aphrodite/quantization/quip.py

@@ -1,20 +1,15 @@
-from contextlib import suppress
 from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import LinearMethodBase, LinearBase
+from aphrodite import _custom_ops as ops
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.quip_utils import (get_hadK, get_packed_abs_grid,
                                                matmul_hadU_cuda,
                                                matmul_hadUt_cuda)
-from aphrodite.modeling.utils import set_weight_attrs
-
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
 
 
 class QuipConfig(QuantizationConfig):
@@ -24,8 +19,6 @@ class QuipConfig(QuantizationConfig):
     """
 
     def __init__(self, codebook: int, use_rand: bool) -> None:
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.codebook = codebook
         self.use_rand = use_rand
 
@@ -77,8 +70,6 @@ class QuipLinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: QuipConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
         self.grid_packed_abs = get_packed_abs_grid().to(device="cuda")
         self.pack = 8

+ 1 - 8
aphrodite/quantization/squeezellm.py

@@ -1,20 +1,15 @@
 from typing import Any, Dict, List, Optional
-from contextlib import suppress
 
 import torch
 from torch.nn.parameter import Parameter
 
+from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import is_hip
 from aphrodite.modeling.layers.linear import LinearBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
 
-HAS_QUANTS = False
-with suppress(ImportError):
-    from aphrodite._quant_C import quant_ops as ops
-    HAS_QUANTS = True
-
 
 class SqueezeLLMConfig(QuantizationConfig):
     """Config class for SqueezeLLM.
@@ -75,8 +70,6 @@ class SqueezeLLMLinearMethod(QuantizeMethodBase):
     """
 
     def __init__(self, quant_config: SqueezeLLMConfig):
-        if not HAS_QUANTS:
-            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module,

+ 6 - 5
cmake/cpu_extension.cmake

@@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/kernels")
 #
 # Check the compile flags
 #
-list(APPEND CXX_COMPILE_FLAGS 
+list(APPEND CXX_COMPILE_FLAGS
     "-fopenmp"
     "-DAPHRODITE_CPU_EXTENSION")
 
@@ -44,8 +44,8 @@ if (AVX512_FOUND)
 
     find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
     if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
-        if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND 
-            CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) 
+        if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
+            CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
             list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
         else()
             message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
@@ -73,7 +73,7 @@ set(APHRODITE_EXT_SRC
     "kernels/cpu/cache.cpp"
     "kernels/cpu/layernorm.cpp"
     "kernels/cpu/pos_encoding.cpp"
-    "kernels/cpu/pybind.cpp")
+    "kernels/cpu/torch_bindings.cpp")
 
 define_gpu_extension_target(
     _C
@@ -81,7 +81,8 @@ define_gpu_extension_target(
     LANGUAGE CXX
     SOURCES ${APHRODITE_EXT_SRC}
     COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
-    WITH_SOABI 
+    USE_SABI 3
+    WITH_SOABI
 )
 
 add_custom_target(default)

+ 8 - 3
cmake/utils.cmake

@@ -5,7 +5,7 @@
 macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
   file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
   set(Python_EXECUTABLE ${EXECUTABLE})
-  find_package(Python COMPONENTS Interpreter Development.Module)
+  find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
   if (NOT Python_FOUND)
     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
   endif()
@@ -294,6 +294,7 @@ endmacro()
 # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
 # LIBRARIES <libraries>      - Extra link libraries.
 # WITH_SOABI                 - Generate library with python SOABI suffix name.
+# USE_SABI <version>         - Use python stable api <version>
 #
 # Note: optimization level/debug info is set via cmake build type.
 #
@@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   cmake_parse_arguments(PARSE_ARGV 1
     GPU
     "WITH_SOABI"
-    "DESTINATION;LANGUAGE"
+    "DESTINATION;LANGUAGE;USE_SABI"
     "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
 
   # Add hipify preprocessing step when building with HIP/ROCm.
@@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
     set(GPU_WITH_SOABI)
   endif()
 
-  Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
+  if (GPU_USE_SABI)
+    Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  else()
+    Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  endif()
 
   if (GPU_LANGUAGE STREQUAL "HIP")
     # Make this target dependent on the hipify preprocessor step.

+ 66 - 76
kernels/activation_kernels.cu

@@ -1,5 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
@@ -10,11 +10,11 @@
 namespace aphrodite {
 
 // Activation and gating kernel template.
-template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void act_and_mul_kernel(
-  scalar_t* __restrict__ out,               // [..., d]
-  const scalar_t* __restrict__ input,       // [..., 2, d]
-  const int d) {
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
@@ -23,70 +23,64 @@ __global__ void act_and_mul_kernel(
   }
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
   // Refer to:
   // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
-  const float f = (float) x;
+  const float f = (float)x;
   constexpr float ALPHA = M_SQRT1_2;
-  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   // Equivalent to PyTorch GELU with `tanh` approximation
-  const float f = (float) x;
+  const float f = (float)x;
   constexpr float BETA = M_SQRT2 * M_2_SQRTPI / 0.5f;
   constexpr float KAPPA = 0.044715;
   float x_cube = f * f * f;
   float inner = BETA * (f + KAPPA * x_cube);
-  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
 }
 
-} // namespace aphrodite
+}  // namespace aphrodite
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                                 \
-  int d = input.size(-1) / 2;                                                                 \
-  int64_t num_tokens = input.numel() / input.size(-1);                                        \
-  dim3 grid(num_tokens);                                                                      \
-  dim3 block(std::min(d, 1024));                                                              \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                           \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                               \
-  APHRODITE_DISPATCH_FLOATING_TYPES(                                                          \
-    input.scalar_type(),                                                                      \
-    "act_and_mul_kernel",                                                                     \
-    [&] {                                                                                     \
-      aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(  \
-        out.data_ptr<scalar_t>(),                                                             \
-        input.data_ptr<scalar_t>(),                                                           \
-        d);                                                                                   \
-    });
-
-void silu_and_mul(
-  torch::Tensor& out,      // [..., d]
-  torch::Tensor& input)    // [..., 2 * d]
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
+  int d = input.size(-1) / 2;                                            \
+  int64_t num_tokens = input.numel() / input.size(-1);                   \
+  dim3 grid(num_tokens);                                                 \
+  dim3 block(std::min(d, 1024));                                         \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
+  APHRODITE_DISPATCH_FLOATING_TYPES(                                     \
+      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
+        aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>        \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
+                                         input.data_ptr<scalar_t>(), d); \
+      });
+
+void silu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel);
 }
 
-void gelu_and_mul(
-  torch::Tensor& out,      // [..., d]
-  torch::Tensor& input)    // [..., 2 * d]
+void gelu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel);
 }
 
-void gelu_tanh_and_mul(
-  torch::Tensor& out,      // [..., d]
-  torch::Tensor& input)    // [..., 2 * d]
+void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_tanh_kernel);
 }
@@ -94,11 +88,11 @@ void gelu_tanh_and_mul(
 namespace aphrodite {
 
 // Element-wise activation kernel template.
-template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
-  scalar_t* __restrict__ out,               // [..., d]
-  const scalar_t* __restrict__ input,       // [..., d]
-  const int d) {
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., d]
+    const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = APHRODITE_LDG(&input[token_idx * d + idx]);
@@ -106,54 +100,50 @@ __global__ void activation_kernel(
   }
 }
 
-} // namespace aphrodite
+}  // namespace aphrodite
 
 // Launch element-wise activation kernel.
-#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
-  int d = input.size(-1);                                                                 \
-  int64_t num_tokens = input.numel() / d;                                                 \
-  dim3 grid(num_tokens);                                                                  \
-  dim3 block(std::min(d, 1024));                                                          \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
-  APHRODITE_DISPATCH_FLOATING_TYPES(                                                           \
-    input.scalar_type(),                                                                  \
-    "activation_kernel",                                                                  \
-    [&] {                                                                                 \
-      aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
-        out.data_ptr<scalar_t>(),                                                         \
-        input.data_ptr<scalar_t>(),                                                       \
-        d);                                                                               \
-    });
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                 \
+  int d = input.size(-1);                                                \
+  int64_t num_tokens = input.numel() / d;                                \
+  dim3 grid(num_tokens);                                                 \
+  dim3 block(std::min(d, 1024));                                         \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
+  APHRODITE_DISPATCH_FLOATING_TYPES(                                     \
+      input.scalar_type(), "activation_kernel", [&] {                    \
+        aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>>         \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
+                                         input.data_ptr<scalar_t>(), d); \
+      });
 
 namespace aphrodite {
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_new_kernel(const T& x) {
-  const float x3 = (float) (x * x * x);
-  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
-  return ((T) 0.5) * x * (((T) 1.0) + t);
+  const float x3 = (float)(x * x * x);
+  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
+  return ((T)0.5) * x * (((T)1.0) + t);
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
-  const float f = (float) x;
-  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
-  return ((T) 0.5) * x * (((T) 1.0) + t);
+  const float f = (float)x;
+  const T t =
+      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
+  return ((T)0.5) * x * (((T)1.0) + t);
 }
 
-} // namespace aphrodite
+}  // namespace aphrodite
 
-void gelu_new(
-  torch::Tensor& out,     // [..., d]
-  torch::Tensor& input)   // [..., d]
+void gelu_new(torch::Tensor& out,    // [..., d]
+              torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_new_kernel);
 }
 
-void gelu_fast(
-  torch::Tensor& out,     // [..., d]
-  torch::Tensor& input)   // [..., d]
+void gelu_fast(torch::Tensor& out,    // [..., d]
+               torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_fast_kernel);
 }

+ 41 - 35
kernels/all_reduce/custom_all_reduce.cu

@@ -1,17 +1,17 @@
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include "custom_all_reduce.cuh"
 
-// fake pointer type
-using fptr_t = uint64_t;
-static_assert(sizeof(void *) == sizeof(fptr_t));
+// fake pointer type, must match fptr_t type in ops.h
+using fptr_t = int64_t;
+static_assert(sizeof(void*) == sizeof(fptr_t));
 
-fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
-                      const std::vector<std::string> &handles,
-                      const std::vector<int64_t> &offsets, int rank,
+fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
+                      const std::vector<std::string>& handles,
+                      const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink) {
   int world_size = offsets.size();
   if (world_size > 8)
@@ -29,8 +29,9 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
     std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
   }
   return (fptr_t) new aphrodite::CustomAllreduce(
-      reinterpret_cast<aphrodite::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
-      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
+      reinterpret_cast<aphrodite::Signal*>(meta.data_ptr()),
+      rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank,
+      full_nvlink);
 }
 
 /**
@@ -49,13 +50,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
  * 5. A[None].expand(2, -1, -1, -1): Not OK
  * 6. A[:, 1:, 1:]: Not OK
  */
-bool _is_weak_contiguous(torch::Tensor &t) {
+bool _is_weak_contiguous(torch::Tensor& t) {
   return t.is_contiguous() ||
          (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
           t.numel() * t.element_size());
 }
 
-bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                       bool full_nvlink) {
   auto inp_size = inp.numel() * inp.element_size();
   // custom allreduce requires input byte size to be multiples of 16
@@ -67,28 +68,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
   return false;
 }
 
-void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
+void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                  cudaStream_t stream) {
-  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
+  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
   TORCH_CHECK(_is_weak_contiguous(out));
   switch (out.scalar_type()) {
     case at::ScalarType::Float: {
-      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
-                           reinterpret_cast<float *>(out.data_ptr()),
+      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
+                           reinterpret_cast<float*>(out.data_ptr()),
                            out.numel());
       break;
     }
     case at::ScalarType::Half: {
-      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
-                          reinterpret_cast<half *>(out.data_ptr()),
-                          out.numel());
+      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
+                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
       break;
     }
 #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
     case at::ScalarType::BFloat16: {
       fa->allreduce<nv_bfloat16>(
-          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
-          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
+          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
+          reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
       break;
     }
 #endif
@@ -98,7 +98,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
   }
 }
 
-void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
+void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
   auto stream = c10::cuda::getCurrentCUDAStream().stream();
   TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
@@ -106,8 +106,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
   _all_reduce(_fa, inp, out, stream);
 }
 
-void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
-                      torch::Tensor &out) {
+void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
+                      torch::Tensor& out) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
   auto stream = c10::cuda::getCurrentCUDAStream().stream();
 
@@ -122,27 +122,33 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
 }
 
 void dispose(fptr_t _fa) {
-  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
+  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
   delete fa;
 }
 
-int meta_size() { return sizeof(aphrodite::Signal); }
+int64_t meta_size() { return sizeof(aphrodite::Signal); }
 
-void register_buffer(fptr_t _fa, torch::Tensor &t,
-                     const std::vector<std::string> &handles,
-                     const std::vector<int64_t> &offsets) {
-  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
+void register_buffer(fptr_t _fa, torch::Tensor& t,
+                     const std::vector<std::string>& handles,
+                     const std::vector<int64_t>& offsets) {
+  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
   fa->register_buffer(handles, offsets, t.data_ptr());
 }
 
-std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa) {
-  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
-  return fa->get_graph_buffer_ipc_meta();
+  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
+  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto handles =
+      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
+  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
+  return {handles, std::move(offsets)};
 }
 
-void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
-                            const std::vector<std::vector<int64_t>> &offsets) {
-  auto fa = reinterpret_cast<aphrodite::CustomAllreduce *>(_fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets) {
+  auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
   fa->register_graph_buffers(handles, offsets);
 }

+ 17 - 15
kernels/attention/attention_kernels.cu

@@ -18,7 +18,7 @@
  * limitations under the License.
  */
 
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <algorithm>
@@ -809,16 +809,17 @@ void paged_attention_v1(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,   // [num_blocks, num_heads, head_size, block_size]
-    int num_kv_heads,  // [num_heads]
-    float scale,
+        value_cache,       // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads,  // [num_heads]
+    double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int block_size, int max_seq_len,
+    int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
 
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
@@ -973,16 +974,17 @@ void paged_attention_v2(
     torch::Tensor&
         key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
     torch::Tensor&
-        value_cache,   // [num_blocks, num_heads, head_size, block_size]
-    int num_kv_heads,  // [num_heads]
-    float scale,
+        value_cache,       // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads,  // [num_heads]
+    double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& seq_lens,      // [num_seqs]
-    int block_size, int max_seq_len,
+    int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)

+ 9 - 5
kernels/cache.h

@@ -1,6 +1,6 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <map>
 #include <vector>
@@ -8,14 +8,18 @@
 void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  const torch::Tensor& block_mapping);
 
-void copy_blocks(std::vector<torch::Tensor>& key_caches,
-                 std::vector<torch::Tensor>& value_caches,
+// NOTE: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);
 
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, const float kv_scale);
+                       const std::string& kv_cache_dtype,
+                       const double kv_scale);
 
 void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
@@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
 
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
-                 const float scale, const std::string& kv_cache_dtype);
+                 const double scale, const std::string& kv_cache_dtype);

+ 8 - 5
kernels/cache_kernels.cu

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
 
 }  // namespace aphrodite
 
-void copy_blocks(std::vector<torch::Tensor>& key_caches,
-                 std::vector<torch::Tensor>& value_caches,
+// NOTE: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
@@ -255,7 +258,7 @@ void reshape_and_cache(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype, const float kv_scale) {
+    const std::string& kv_cache_dtype, const double kv_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -336,7 +339,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
 
 // Only for testing.
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
-                 const float kv_scale, const std::string& kv_cache_dtype) {
+                 const double kv_scale, const std::string& kv_cache_dtype) {
   torch::Device src_device = src_cache.device();
   torch::Device dst_device = dst_cache.device();
   TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")

+ 14 - 12
kernels/cpu/attention.cpp

@@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher(
 
 void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
@@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher(
 void paged_attention_v2(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");

+ 8 - 5
kernels/cpu/cache.cpp

@@ -5,8 +5,8 @@
 
 namespace {
 template <typename scalar_t>
-void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
-                          std::vector<torch::Tensor>& value_caches,
+void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
+                          std::vector<torch::Tensor> const& value_caches,
                           const torch::Tensor& mapping_pairs,
                           const int element_num_per_block,
                           const int layer_num) {
@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
 }
 };  // namespace
 
-void copy_blocks(std::vector<torch::Tensor>& key_caches,
-                 std::vector<torch::Tensor>& value_caches,
+// NOTE: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping) {
   unsigned num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, float kv_scale) {
+                       const std::string& kv_cache_dtype, double kv_scale) {
   TORCH_CHECK(kv_scale == 1.0f);
   int num_tokens = key.size(0);
   int num_heads = key.size(1);

+ 1 - 1
kernels/cpu/cpu_types.hpp

@@ -2,7 +2,7 @@
 #define CPU_TYPES_HPP
 
 #include <immintrin.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 
 namespace vec_op {
 

+ 16 - 16
kernels/cpu/layernorm.cpp

@@ -2,10 +2,10 @@
 
 namespace {
 template <typename scalar_t>
-void rms_norm_impl(scalar_t *__restrict__ out,
-                       const scalar_t *__restrict__ input,
-                       const scalar_t *__restrict__ weight, const float epsilon,
-                       const int num_tokens, const int hidden_size) {
+void rms_norm_impl(scalar_t* __restrict__ out,
+                   const scalar_t* __restrict__ input,
+                   const scalar_t* __restrict__ weight, const float epsilon,
+                   const int num_tokens, const int hidden_size) {
   using scalar_vec_t = vec_op::vec_t<scalar_t>;
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
   TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
 }
 
 template <typename scalar_t>
-void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
-                                 scalar_t *__restrict__ residual,
-                                 const scalar_t *__restrict__ weight,
-                                 const float epsilon, const int num_tokens,
-                                 const int hidden_size) {
+void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
+                             scalar_t* __restrict__ residual,
+                             const scalar_t* __restrict__ weight,
+                             const float epsilon, const int num_tokens,
+                             const int hidden_size) {
   using scalar_vec_t = vec_op::vec_t<scalar_t>;
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
   TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
     }
   }
 }
-} // namespace
+}  // namespace
 
-void rms_norm(torch::Tensor &out, torch::Tensor &input,
-                  torch::Tensor &weight, float epsilon) {
+void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
+              double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
 
   APHRODITE_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
     CPU_KERNEL_GUARD_IN(rms_norm_impl)
     rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
-                      weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-                      hidden_size);
+                  weight.data_ptr<scalar_t>(), epsilon, num_tokens,
+                  hidden_size);
     CPU_KERNEL_GUARD_OUT(rms_norm_impl)
   });
 }
 
-void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
-                            torch::Tensor &weight, float epsilon) {
+void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
+                        torch::Tensor& weight, double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
 

+ 1 - 1
kernels/cpu/pos_encoding.cpp

@@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl(
 };  // namespace
 
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor& key, int head_size,
+                      torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox) {
   int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);

+ 0 - 44
kernels/cpu/pybind.cpp

@@ -1,44 +0,0 @@
-#include "../cache.h"
-#include "../ops.h"
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  // Aphrodite custom ops
-  pybind11::module ops = m.def_submodule("ops", "Aphrodite custom operators");
-
-  // Attention ops
-  ops.def("paged_attention_v1", &paged_attention_v1,
-          "Compute the attention between an input query and the cached "
-          "keys/values using PagedAttention.");
-  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
-
-  // Activation ops
-  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
-  ops.def("gelu_and_mul", &gelu_and_mul,
-          "Activation function used in GeGLU with `none` approximation.");
-  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
-          "Activation function used in GeGLU with `tanh` approximation.");
-  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
-  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
-
-  // Layernorm
-  ops.def("rms_norm", &rms_norm,
-          "Apply Root Mean Square (RMS) Normalization to the input tensor.");
-
-  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
-          "In-place fused Add and RMS Normalization");
-
-  // Rotary embedding
-  ops.def("rotary_embedding", &rotary_embedding,
-          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
-
-  // Cache ops
-  pybind11::module cache_ops =
-      m.def_submodule("cache_ops", "Aphrodite cache ops");
-  cache_ops.def("swap_blocks", &swap_blocks,
-                "Swap in (out) the cache blocks from src to dst");
-  cache_ops.def("copy_blocks", &copy_blocks,
-                "Copy the cache blocks from src to dst");
-  cache_ops.def("reshape_and_cache", &reshape_and_cache,
-                "Reshape the key and value tensors and cache them");
-}

+ 106 - 0
kernels/cpu/torch_bindings.cpp

@@ -0,0 +1,106 @@
+#include "cache.h"
+#include "ops.h"
+#include "registration.h"
+
+#include <torch/library.h>
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  // Aphrodite custom ops
+
+  // Attention ops
+  // Compute the attention between an input query and the cached keys/values
+  // using PagedAttention.
+  ops.def(
+      "paged_attention_v1("
+      "    Tensor! out, Tensor query, Tensor key_cache,"
+      "    Tensor value_cache, int num_kv_heads, float scale,"
+      "    Tensor block_tables, Tensor seq_lens, int block_size,"
+      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
+      "    int blocksparse_local_blocks,"
+      "    int blocksparse_vert_stride, int blocksparse_block_size,"
+      "    int blocksparse_head_sliding_step) -> ()");
+  ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
+
+  // PagedAttention V2.
+  ops.def(
+      "paged_attention_v2("
+      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
+      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
+      "    Tensor value_cache, int num_kv_heads, float scale,"
+      "    Tensor block_tables, Tensor seq_lens, int block_size,"
+      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
+      "    int blocksparse_local_blocks,"
+      "    int blocksparse_vert_stride, int blocksparse_block_size,"
+      "    int blocksparse_head_sliding_step) -> ()");
+  ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
+
+  // Activation ops
+
+  // Activation function used in SwiGLU.
+  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);
+
+  // Activation function used in GeGLU with `none` approximation.
+  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);
+
+  // Activation function used in GeGLU with `tanh` approximation.
+  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);
+
+  // GELU implementation used in GPT-2.
+  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_new", torch::kCPU, &gelu_new);
+
+  // Approximate GELU implementation.
+  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
+
+  // Layernorm
+  // Apply Root Mean Square (RMS) Normalization to the input tensor.
+  ops.def(
+      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
+      "()");
+  ops.impl("rms_norm", torch::kCPU, &rms_norm);
+
+  // In-place fused Add and RMS Normalization.
+  ops.def(
+      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
+      "float epsilon) -> ()");
+  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);
+
+  // Rotary embedding
+  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
+  ops.def(
+      "rotary_embedding(Tensor positions, Tensor! query,"
+      "                 Tensor! key, int head_size,"
+      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
+  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
+}
+
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
+  // Cache ops
+  // Swap in (out) the cache blocks from src to dst.
+  cache_ops.def(
+      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
+  cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
+
+  // Copy the cache blocks from src to dst.
+  cache_ops.def(
+      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
+      "block_mapping) -> ()");
+  cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
+
+  // Reshape the key and value tensors and cache them.
+  cache_ops.def(
+      "reshape_and_cache(Tensor key, Tensor value,"
+      "                  Tensor! key_cache, Tensor! value_cache,"
+      "                  Tensor slot_mapping,"
+      "                  str kv_cache_dtype,"
+      "                  float kv_scale) -> ()");
+  cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

+ 2 - 7
kernels/cuda_utils.h

@@ -1,10 +1,5 @@
 #pragma once
 
-#include <torch/extension.h>
+int64_t get_device_attribute(int64_t attribute, int64_t device_id);
 
-int get_device_attribute(
-    int attribute,
-    int device_id);
-
-int get_max_shared_memory_per_block_device_attribute(
-    int device_id);
+int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

+ 17 - 22
kernels/cuda_utils_kernels.cu

@@ -2,33 +2,28 @@
   #include <hip/hip_runtime.h>
   #include <hip/hip_runtime_api.h>
 #endif
-int get_device_attribute(
-    int attribute,
-    int device_id)
-{
-    int device, value;
-    if (device_id < 0) {
-        cudaGetDevice(&device);
-    }
-    else {
-        device = device_id;
-    }
-    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
-    return value;
+int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
+  int device, value;
+  if (device_id < 0) {
+    cudaGetDevice(&device);
+  } else {
+    device = device_id;
+  }
+  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
+                         device);
+  return value;
 }
 
-int get_max_shared_memory_per_block_device_attribute(
-    int device_id)
-{
-int attribute;    
-// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
+  int64_t attribute;
+  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
 
 #ifdef USE_ROCM
-    attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
 #else
-    attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
 #endif
 
-    return get_device_attribute(attribute, device_id);
+  return get_device_attribute(attribute, device_id);
 }

+ 1 - 1
kernels/dispatch_utils.h

@@ -4,7 +4,7 @@
  */
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)    \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \

+ 121 - 120
kernels/layernorm_kernels.cu

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -11,26 +11,24 @@
   #include <hip/hip_bf16.h>
   #include <hip/hip_fp16.h>
 
-  using __nv_bfloat16 = __hip_bfloat16;
-  using __nv_bfloat162 = __hip_bfloat162;
+using __nv_bfloat16 = __hip_bfloat16;
+using __nv_bfloat162 = __hip_bfloat162;
 #endif
 
 namespace aphrodite {
 
 // TODO: Further optimize this kernel.
-template<typename scalar_t>
+template <typename scalar_t>
 __global__ void rms_norm_kernel(
-  scalar_t* __restrict__ out,             // [..., hidden_size]
-  const scalar_t* __restrict__ input,     // [..., hidden_size]
-  const scalar_t* __restrict__ weight,    // [hidden_size]
-  const float epsilon,
-  const int num_tokens,
-  const int hidden_size) {
+    scalar_t* __restrict__ out,           // [..., hidden_size]
+    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    const scalar_t* __restrict__ weight,  // [hidden_size]
+    const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float) input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * hidden_size + idx];
     variance += x * x;
   }
   variance = blockReduceSum<float>(variance);
@@ -40,8 +38,9 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float) input[blockIdx.x * hidden_size + idx];
-    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+    float x = (float)input[blockIdx.x * hidden_size + idx];
+    out[blockIdx.x * hidden_size + idx] =
+        ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
 
@@ -52,52 +51,68 @@ __global__ void rms_norm_kernel(
    a generic conversion via type casts cannot be implemented.
    Each struct should have the member static constexpr bool `exists`:
    If false, the optimized kernel is not used for the corresponding torch type.
-   If true, the struct should be fully defined as shown in the examples below. 
+   If true, the struct should be fully defined as shown in the examples below.
  */
-template<typename torch_type>
-struct _typeConvert { static constexpr bool exists = false; };
+template <typename torch_type>
+struct _typeConvert {
+  static constexpr bool exists = false;
+};
 
 #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
 // CUDA < 12.0 runs into issues with packed type conversion
-template<>
+template <>
 struct _typeConvert<c10::Half> {
   static constexpr bool exists = true;
   using hip_type = __half;
   using packed_hip_type = __half2;
 
   __device__ static inline float convert(hip_type x) { return __half2float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
-  __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
-  __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
+  __device__ static inline float2 convert(packed_hip_type x) {
+    return __half22float2(x);
+  }
+  __device__ static inline hip_type convert(float x) {
+    return __float2half_rn(x);
+  }
+  __device__ static inline packed_hip_type convert(float2 x) {
+    return __float22half2_rn(x);
+  }
 };
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 // CUDA_ARCH < 800 does not have BF16 support
 // TODO: Add in ROCm support once public headers handle bf16 maturely
-template<>
+template <>
 struct _typeConvert<c10::BFloat16> {
   static constexpr bool exists = true;
   using hip_type = __nv_bfloat16;
   using packed_hip_type = __nv_bfloat162;
 
-  __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
-  __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
-  __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
+  __device__ static inline float convert(hip_type x) {
+    return __bfloat162float(x);
+  }
+  __device__ static inline float2 convert(packed_hip_type x) {
+    return __bfloat1622float2(x);
+  }
+  __device__ static inline hip_type convert(float x) {
+    return __float2bfloat16(x);
+  }
+  __device__ static inline packed_hip_type convert(float2 x) {
+    return __float22bfloat162_rn(x);
+  }
 };
-#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
-
+  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#endif    // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
+          // 12000))
 
 /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
    for appropriate specializations of fused_add_rms_norm_kernel.
    Only functions that are necessary in that kernel are implemented.
    Alignment to 16 bytes is required to use 128-bit global memory ops.
  */
-template<typename scalar_t, int width>
+template <typename scalar_t, int width>
 struct alignas(16) _f16Vec {
-  /* Not theoretically necessary that width is a power of 2 but should 
-     almost always be the case for optimization purposes */ 
+  /* Not theoretically necessary that width is a power of 2 but should
+     almost always be the case for optimization purposes */
   static_assert(width > 0 && (width & (width - 1)) == 0,
                 "Width is not a positive power of 2!");
   using Converter = _typeConvert<scalar_t>;
@@ -107,51 +122,49 @@ struct alignas(16) _f16Vec {
 
   __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
     if constexpr (width % 2 == 0) {
-      #pragma unroll
+#pragma unroll
       for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i+1]};
-        temp += T2{other.data[i], other.data[i+1]};
+        T2 temp{data[i], data[i + 1]};
+        temp += T2{other.data[i], other.data[i + 1]};
         data[i] = temp.x;
-        data[i+1] = temp.y;
+        data[i + 1] = temp.y;
       }
     } else {
-      #pragma unroll
-      for (int i = 0; i < width; ++i)
-        data[i] += other.data[i];
+#pragma unroll
+      for (int i = 0; i < width; ++i) data[i] += other.data[i];
     }
     return *this;
   }
 
   __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
     if constexpr (width % 2 == 0) {
-      #pragma unroll
+#pragma unroll
       for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i+1]};
-        temp *= T2{other.data[i], other.data[i+1]};
+        T2 temp{data[i], data[i + 1]};
+        temp *= T2{other.data[i], other.data[i + 1]};
         data[i] = temp.x;
-        data[i+1] = temp.y;
+        data[i + 1] = temp.y;
       }
     } else {
-      #pragma unroll
-      for (int i = 0; i < width; ++i)
-        data[i] *= other.data[i];
+#pragma unroll
+      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
     }
     return *this;
   }
 
   __device__ _f16Vec& operator*=(const float scale) {
     if constexpr (width % 2 == 0) {
-      #pragma unroll
+#pragma unroll
       for (int i = 0; i < width; i += 2) {
-        float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
+        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
         temp_f.x *= scale;
         temp_f.y *= scale;
         T2 temp = Converter::convert(temp_f);
         data[i] = temp.x;
-        data[i+1] = temp.y;
+        data[i + 1] = temp.y;
       }
     } else {
-      #pragma unroll
+#pragma unroll
       for (int i = 0; i < width; ++i) {
         float temp = Converter::convert(data[i]) * scale;
         data[i] = Converter::convert(temp);
@@ -163,13 +176,13 @@ struct alignas(16) _f16Vec {
   __device__ float sum_squares() const {
     float result = 0.0f;
     if constexpr (width % 2 == 0) {
-      #pragma unroll
+#pragma unroll
       for (int i = 0; i < width; i += 2) {
-        float2 z = Converter::convert(T2{data[i], data[i+1]});
+        float2 z = Converter::convert(T2{data[i], data[i + 1]});
         result += z.x * z.x + z.y * z.y;
       }
     } else {
-      #pragma unroll
+#pragma unroll
       for (int i = 0; i < width; ++i) {
         float x = Converter::convert(data[i]);
         result += x * x;
@@ -183,15 +196,13 @@ struct alignas(16) _f16Vec {
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
    memory latency bottleneck. */
-template<typename scalar_t, int width>
-__global__ std::enable_if_t<
-  (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
-  scalar_t* __restrict__ input,           // [..., hidden_size]
-  scalar_t* __restrict__ residual,        // [..., hidden_size]
-  const scalar_t* __restrict__ weight,    // [hidden_size]
-  const float epsilon,
-  const int num_tokens,
-  const int hidden_size) {
+template <typename scalar_t, int width>
+__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
+fused_add_rms_norm_kernel(
+    scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ residual,      // [..., hidden_size]
+    const scalar_t* __restrict__ weight,  // [hidden_size]
+    const float epsilon, const int num_tokens, const int hidden_size) {
   // Sanity checks on our vector struct and type-punned pointer arithmetic
   static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@@ -202,9 +213,12 @@ __global__ std::enable_if_t<
   /* These and the argument pointers are all declared `restrict` as they are
      not aliased in practice. Argument pointers should not be dereferenced
      in this kernel as that would be undefined behavior */
-  auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
-  auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
-  auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
+  auto* __restrict__ input_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
+  auto* __restrict__ residual_v =
+      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
+  auto* __restrict__ weight_v =
+      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
@@ -214,10 +228,11 @@ __global__ std::enable_if_t<
     residual_v[id] = temp;
   }
   /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */ 
+     calculation of max_block_size in fused_add_rms_norm */
   if (num_tokens < 256) {
     variance = blockReduceSum<float, 1024>(variance);
-  } else variance = blockReduceSum<float, 256>(variance);
+  } else
+    variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -232,52 +247,50 @@ __global__ std::enable_if_t<
   }
 }
 
-
 /* Generic fused_add_rms_norm_kernel
    The width field is not used here but necessary for other specializations.
  */
-template<typename scalar_t, int width>
-__global__ std::enable_if_t<
-  (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
-  scalar_t* __restrict__ input,           // [..., hidden_size]
-  scalar_t* __restrict__ residual,        // [..., hidden_size]
-  const scalar_t* __restrict__ weight,    // [hidden_size]
-  const float epsilon,
-  const int num_tokens,
-  const int hidden_size) {
+template <typename scalar_t, int width>
+__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
+fused_add_rms_norm_kernel(
+    scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ residual,      // [..., hidden_size]
+    const scalar_t* __restrict__ weight,  // [hidden_size]
+    const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     scalar_t z = input[blockIdx.x * hidden_size + idx];
     z += residual[blockIdx.x * hidden_size + idx];
-    float x = (float) z;
+    float x = (float)z;
     variance += x * x;
     residual[blockIdx.x * hidden_size + idx] = z;
   }
   /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */ 
+     calculation of max_block_size in fused_add_rms_norm */
   if (num_tokens < 256) {
     variance = blockReduceSum<float, 1024>(variance);
-  } else variance = blockReduceSum<float, 256>(variance);
+  } else
+    variance = blockReduceSum<float, 256>(variance);
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float) residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+    float x = (float)residual[blockIdx.x * hidden_size + idx];
+    input[blockIdx.x * hidden_size + idx] =
+        ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
 
-} // namespace aphrodite
+}  // namespace aphrodite
 
-void rms_norm(
-  torch::Tensor& out,      // [..., hidden_size]
-  torch::Tensor& input,    // [..., hidden_size]
-  torch::Tensor& weight,   // [hidden_size]
-  float epsilon) {
+void rms_norm(torch::Tensor& out,     // [..., hidden_size]
+              torch::Tensor& input,   // [..., hidden_size]
+              torch::Tensor& weight,  // [hidden_size]
+              double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
 
@@ -286,39 +299,27 @@ void rms_norm(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "rms_norm_kernel",
-    [&] {
-      aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        weight.data_ptr<scalar_t>(),
-        epsilon,
-        num_tokens,
-        hidden_size);
-    });
+      input.scalar_type(), "rms_norm_kernel", [&] {
+        aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+      });
 }
 
-#define LAUNCH_FUSED_ADD_RMS_NORM(width)              \
-  APHRODITE_DISPATCH_FLOATING_TYPES(                  \
-    input.scalar_type(),                              \
-    "fused_add_rms_norm_kernel",                      \
-    [&] {                                             \
-      aphrodite::fused_add_rms_norm_kernel            \
-      <scalar_t, width><<<grid, block, 0, stream>>>(  \
-        input.data_ptr<scalar_t>(),                   \
-        residual.data_ptr<scalar_t>(),                \
-        weight.data_ptr<scalar_t>(),                  \
-        epsilon,                                      \
-        num_tokens,                                   \
-        hidden_size);                                 \
-    });
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
+  APHRODITE_DISPATCH_FLOATING_TYPES(                                           \
+      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
+        aphrodite::fused_add_rms_norm_kernel<scalar_t, width>                  \
+            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
+                                         residual.data_ptr<scalar_t>(),        \
+                                         weight.data_ptr<scalar_t>(), epsilon, \
+                                         num_tokens, hidden_size);             \
+      });
 
-void fused_add_rms_norm(
-  torch::Tensor& input,    // [..., hidden_size]
-  torch::Tensor& residual, // [..., hidden_size]
-  torch::Tensor& weight,   // [hidden_size]
-  float epsilon) {
+void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
+                        torch::Tensor& residual,  // [..., hidden_size]
+                        torch::Tensor& weight,    // [hidden_size]
+                        double epsilon) {
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
 
@@ -341,8 +342,8 @@ void fused_add_rms_norm(
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
-                          && wt_ptr % 16 == 0;
+  bool ptrs_are_aligned =
+      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
   if (ptrs_are_aligned && hidden_size % 8 == 0) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {

+ 111 - 102
kernels/moe/align_block_size_kernel.cu

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 
 #include <ATen/ATen.h>
@@ -7,119 +7,128 @@
 #include "../cuda_compat.h"
 #include "../dispatch_utils.h"
 
-#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
+#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
 
 namespace aphrodite {
 
 namespace {
-__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
-    // don't worry about overflow because num_experts is relatively small
-    return row * total_col + col;
-}
+__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
+                                         int32_t col) {
+  // don't worry about overflow because num_experts is relatively small
+  return row * total_col + col;
 }
+}  // namespace
 
 template <typename scalar_t>
-__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, 
-                                int32_t *sorted_token_ids, 
-                                int32_t *expert_ids, 
-                                int32_t *total_tokens_post_pad,
-                                int32_t num_experts, 
-                                int32_t block_size, 
-                                size_t numel) {
-    const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-    const size_t start_idx = threadIdx.x * tokens_per_thread;
-
-    extern __shared__ int32_t shared_mem[];
-
-    int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
-    int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
-
-    for (int i = 0; i < num_experts; ++i) {
-        tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
-    }
-
-    /**
-    * In the first step we compute token_cnts[thread_index + 1][expert_index],
-    * which counts how many tokens in the token shard of thread_index are assigned
-    * to expert expert_index.
-    */
-    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-        ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; 
-    }
-
-    __syncthreads();
-
-    // For each expert we accumulate the token counts from the different threads.
-    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
-    for (int i = 1; i <= blockDim.x; ++i) {
-        tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
-    }
-
-    __syncthreads();
-
-    // We accumulate the token counts of all experts in thread 0.
-    if (threadIdx.x == 0) {
-        cumsum[0] = 0;
-        for (int i = 1; i <= num_experts; ++i) {
-            cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
-        }
-        *total_tokens_post_pad = cumsum[num_experts];
-    }
-
-    __syncthreads();
-
-    /**
-    * For each expert, each thread processes the tokens of the corresponding blocks
-    * and stores the corresponding expert_id for each block.
-    */
-    for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
-        expert_ids[i / block_size] = threadIdx.x;
-    }
-
-    /**
-    * Each thread processes a token shard, calculating the index of each token after
-    * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
-    * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
-    * where * represents a padding value(preset in python).
-    */
-    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-        int32_t expert_id = topk_ids[i];
-        /** The cumsum[expert_id] stores the starting index of the tokens that the
-        * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
-        * stores the indices of the tokens processed by the expert with expert_id within
-        * the current thread's token shard.
-        */
-        int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
-        sorted_token_ids[rank_post_pad] = i;
-        ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
+                                            int32_t* sorted_token_ids,
+                                            int32_t* expert_ids,
+                                            int32_t* total_tokens_post_pad,
+                                            int32_t num_experts,
+                                            int32_t block_size, size_t numel) {
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+  const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+  extern __shared__ int32_t shared_mem[];
+
+  int32_t* tokens_cnts =
+      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
+  int32_t* cumsum =
+      shared_mem + (num_experts + 1) *
+                       num_experts;  // 1d tensor with shape (num_experts + 1)
+
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
+  }
+
+  /**
+   * In the first step we compute token_cnts[thread_index + 1][expert_index],
+   * which counts how many tokens in the token shard of thread_index are
+   * assigned to expert expert_index.
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
+  }
+
+  __syncthreads();
+
+  // For each expert we accumulate the token counts from the different threads.
+  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+  for (int i = 1; i <= blockDim.x; ++i) {
+    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+  }
+
+  __syncthreads();
+
+  // We accumulate the token counts of all experts in thread 0.
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      cumsum[i] = cumsum[i - 1] +
+                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
+                          block_size) *
+                      block_size;
     }
+    *total_tokens_post_pad = cumsum[num_experts];
+  }
+
+  __syncthreads();
+
+  /**
+   * For each expert, each thread processes the tokens of the corresponding
+   * blocks and stores the corresponding expert_id for each block.
+   */
+  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+       i += block_size) {
+    expert_ids[i / block_size] = threadIdx.x;
+  }
+
+  /**
+   * Each thread processes a token shard, calculating the index of each token
+   * after sorting by expert number. Given the example topk_ids =
+   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
+   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
+   * padding value(preset in python).
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    int32_t expert_id = topk_ids[i];
+    /** The cumsum[expert_id] stores the starting index of the tokens that the
+     * expert with expert_id needs to process, and
+     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
+     * processed by the expert with expert_id within the current thread's token
+     * shard.
+     */
+    int32_t rank_post_pad =
+        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
+        cumsum[expert_id];
+    sorted_token_ids[rank_post_pad] = i;
+    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+  }
 }
-}
-
-void moe_align_block_size(
-    torch::Tensor topk_ids,
-    int num_experts,
-    int block_size,
-    torch::Tensor sorted_token_ids,
-    torch::Tensor experts_ids,
-    torch::Tensor num_tokens_post_pad) {
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    APHRODITE_DISPATCH_INTEGRAL_TYPES(
-        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
-        const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+}  // namespace aphrodite
+
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+                          int64_t block_size, torch::Tensor sorted_token_ids,
+                          torch::Tensor experts_ids,
+                          torch::Tensor num_tokens_post_pad) {
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  APHRODITE_DISPATCH_INTEGRAL_TYPES(
+      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+        // tensors
+        const int32_t shared_mem =
+            ((num_experts + 1) * num_experts + (num_experts + 1)) *
+            sizeof(int32_t);
 
         // set dynamic shared mem
         auto kernel = aphrodite::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(
-            APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
+        AT_CUDA_CHECK(APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+            (void*)kernel, shared_mem));
         kernel<<<1, num_experts, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(),
-            sorted_token_ids.data_ptr<int32_t>(), 
-            experts_ids.data_ptr<int32_t>(), 
-            num_tokens_post_pad.data_ptr<int32_t>(), 
-            num_experts,
-            block_size,
+            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+            experts_ids.data_ptr<int32_t>(),
+            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
             topk_ids.numel());
-    });
+      });
 }

+ 0 - 7
kernels/moe/moe_ops.cpp

@@ -1,7 +0,0 @@
-#include "moe_ops.h"
-
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("topk_softmax", &topk_softmax, "Apply top-k softmax to the gating outputs.");
-}

+ 4 - 6
kernels/moe/moe_ops.h

@@ -1,9 +1,7 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
-void topk_softmax(
-    torch::Tensor& topk_weights,
-    torch::Tensor& topk_indices,
-    torch::Tensor& token_expert_indices,
-    torch::Tensor& gating_output);
+void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
+                  torch::Tensor& token_expert_indices,
+                  torch::Tensor& gating_output);

+ 1 - 1
kernels/moe/softmax.cu

@@ -18,7 +18,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"

+ 12 - 0
kernels/moe/torch_bindings.cpp

@@ -0,0 +1,12 @@
+#include "registration.h"
+#include "moe_ops.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
+  // Apply topk softmax to the gating outputs.
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

+ 27 - 25
kernels/ops.h

@@ -1,40 +1,42 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/library.h>
 
 void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step);
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step);
 
 void paged_attention_v2(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
-    torch::Tensor& value_cache, int num_kv_heads, float scale,
-    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
-    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
-    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
-    const int blocksparse_block_size, const int blocksparse_head_sliding_step);
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
+    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
+    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+    const int64_t blocksparse_head_sliding_step);
 
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
-              float epsilon);
+              double epsilon);
 
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
-                        torch::Tensor& weight, float epsilon);
+                        torch::Tensor& weight, double epsilon);
 
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                      torch::Tensor& key, int head_size,
+                      torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
 
 void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                              torch::Tensor& key, int head_size,
+                              torch::Tensor& key, int64_t head_size,
                               torch::Tensor& cos_sin_cache, bool is_neox,
-                              int rot_dim,
+                              int64_t rot_dim,
                               torch::Tensor& cos_sin_cache_offsets);
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
@@ -47,28 +49,28 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
-void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
-                          int block_size, torch::Tensor sorted_token_ids,
+void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
+                          int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor expert_ids,
                           torch::Tensor num_tokens_post_pad);
 
 #ifndef USE_ROCM
-using fptr_t = uint64_t;
+using fptr_t = int64_t;
 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
-                      const std::vector<int64_t>& offsets, int rank,
+                      const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink);
-bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
+bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                       bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                       torch::Tensor& out);
 void dispose(fptr_t _fa);
-int meta_size();
+int64_t meta_size();
 void register_buffer(fptr_t _fa, torch::Tensor& t,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets);
-std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);

+ 108 - 127
kernels/pos_encoding_kernels.cu

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -7,14 +7,10 @@
 
 namespace aphrodite {
 
-template<typename scalar_t, bool IS_NEOX>
+template <typename scalar_t, bool IS_NEOX>
 inline __device__ void apply_token_rotary_embedding(
-  scalar_t* __restrict__ arr,
-  const scalar_t* __restrict__ cos_ptr,
-  const scalar_t* __restrict__ sin_ptr,
-  int rot_offset,
-  int embed_dim)
-{
+    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
+    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
   int x_index, y_index;
   scalar_t cos, sin;
   if (IS_NEOX) {
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
   arr[y_index] = y * cos + x * sin;
 }
 
-template<typename scalar_t, bool IS_NEOX>
+template <typename scalar_t, bool IS_NEOX>
 inline __device__ void apply_rotary_embedding(
-  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* cache_ptr,
-  const int head_size,
-  const int num_heads,
-  const int num_kv_heads,
-  const int rot_dim,
-  const int token_idx,
-  const int64_t query_stride,
-  const int64_t key_stride)
-{
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* cache_ptr, const int head_size, const int num_heads,
+    const int num_kv_heads, const int rot_dim, const int token_idx,
+    const int64_t query_stride, const int64_t key_stride) {
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
-                                              sin_ptr, rot_offset, embed_dim);
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 
   const int nk = num_kv_heads * embed_dim;
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
-                                              sin_ptr, rot_offset, embed_dim);
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
   }
 }
 
-template<typename scalar_t, bool IS_NEOX>
+template <typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
-  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
-  const int num_heads,
-  const int num_kv_heads,
-  const int head_size) {
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                 // head_size] or [num_tokens, num_kv_heads,
+                                 // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
+    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+    const int num_heads, const int num_kv_heads, const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
-  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+      token_idx, query_stride, key_stride);
 }
 
-template<typename scalar_t, bool IS_NEOX>
+template <typename scalar_t, bool IS_NEOX>
 __global__ void batched_rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens]
-  scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2]
-  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
-  const int num_heads,
-  const int num_kv_heads,
-  const int head_size) {
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
+                                 // head_size] or [num_tokens, num_kv_heads,
+                                 // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
+                                                 // 2]
+    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
+                                                        // or [num_tokens]
+    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+    const int num_heads, const int num_kv_heads, const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
   int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
+  const scalar_t* cache_ptr =
+      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
 
-  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+      token_idx, query_stride, key_stride);
 }
 
-} // namespace aphrodite
+}  // namespace aphrodite
 
 void rotary_embedding(
-  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
-  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
-  int head_size,
-  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
-  bool is_neox) {
+    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
+    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
+                           // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
+                           // [num_tokens, num_kv_heads * head_size]
+    int64_t head_size,
+    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+    bool is_neox) {
   int64_t num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
@@ -132,39 +138,27 @@ void rotary_embedding(
   int64_t key_stride = key.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
-    query.scalar_type(),
-    "rotary_embedding",
-    [&] {
-      if (is_neox) {
-        aphrodite::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      } else {
-        aphrodite::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      }
-    });
+      query.scalar_type(), "rotary_embedding", [&] {
+        if (is_neox) {
+          aphrodite::rotary_embedding_kernel<scalar_t, true>
+              <<<grid, block, 0, stream>>>(
+                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+                  rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
+                  head_size);
+        } else {
+          aphrodite::rotary_embedding_kernel<scalar_t, false>
+              <<<grid, block, 0, stream>>>(
+                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+                  rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
+                  head_size);
+        }
+      });
 }
 
 /*
@@ -172,14 +166,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
 and process in batched manner.
 */
 void batched_rotary_embedding(
-  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
-  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
-  int head_size,
-  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
-  bool is_neox,
-  int rot_dim,
-  torch::Tensor& cos_sin_cache_offsets // [num_tokens]
+    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
+    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
+                           // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
+                           // [num_tokens, num_kv_heads * head_size]
+    int64_t head_size,
+    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+    bool is_neox, int64_t rot_dim,
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
 ) {
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
   int num_heads = query.size(-1) / head_size;
@@ -188,39 +183,25 @@ void batched_rotary_embedding(
   int64_t key_stride = key.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   APHRODITE_DISPATCH_FLOATING_TYPES(
-    query.scalar_type(),
-    "rotary_embedding",
-    [&] {
-      if (is_neox) {
-        aphrodite::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          cos_sin_cache_offsets.data_ptr<int64_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      } else {
-        aphrodite::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          cos_sin_cache_offsets.data_ptr<int64_t>(),
-          rot_dim,
-          query_stride,
-          key_stride,
-          num_heads,
-          num_kv_heads,
-          head_size);
-      }
-    });
+      query.scalar_type(), "rotary_embedding", [&] {
+        if (is_neox) {
+          aphrodite::batched_rotary_embedding_kernel<scalar_t, true>
+              <<<grid, block, 0, stream>>>(
+                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+                  cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim,
+                  query_stride, key_stride, num_heads, num_kv_heads, head_size);
+        } else {
+          aphrodite::batched_rotary_embedding_kernel<scalar_t, false>
+              <<<grid, block, 0, stream>>>(
+                  positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+                  key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+                  cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim,
+                  query_stride, key_stride, num_heads, num_kv_heads, head_size);
+        }
+      });
 }

+ 3 - 3
kernels/punica/punica_ops.cu

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <cstdint>
 
@@ -73,7 +73,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
 }
 
 void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
-                   torch::Tensor indicies, int64_t layer_idx, float scale) {
+                   torch::Tensor indicies, int64_t layer_idx, double scale) {
   CHECK_INPUT(y);
   CHECK_INPUT(x);
   CHECK_INPUT(w);
@@ -305,7 +305,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
 
 void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                              torch::Tensor indicies, int64_t layer_idx,
-                             float scale, int64_t h_in, int64_t h_out,
+                             double scale, int64_t h_in, int64_t h_out,
                              int64_t y_offset) {
   CHECK_INPUT(y);
   CHECK_INPUT(x);

+ 4 - 4
kernels/punica/punica_ops.h

@@ -1,11 +1,11 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
-                   torch::Tensor indicies, int64_t layer_idx, float scale);
+                   torch::Tensor indicies, int64_t layer_idx, double scale);
 
-void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor w,
+void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                              torch::Tensor indicies, int64_t layer_idx,
-                             float scale, int64_t h_in, int64_t h_out,
+                             double scale, int64_t h_in, int64_t h_out,
                              int64_t y_offset);

+ 0 - 11
kernels/punica/punica_pybind.cpp

@@ -1,11 +0,0 @@
-#include <torch/extension.h>
-
-#include "punica_ops.h"
-
-#define DEFINE_pybind(name) m.def(#name, &name, #name);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
-  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
-        "dispatch_bgmv_low_level");
-}

+ 18 - 0
kernels/punica/torch_bindings.cpp

@@ -0,0 +1,18 @@
+#include "registration.h"
+#include "punica_ops.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
+  m.def(
+      "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
+      "layer_idx, float scale) -> ()");
+  m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
+
+  m.def(
+      "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
+      "Tensor indicies, int layer_idx,"
+      "float scale, int h_in, int h_out,"
+      "int y_offset) -> ()");
+  m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

+ 0 - 82
kernels/pybind.cpp

@@ -1,82 +0,0 @@
-#include "cache.h"
-#include "cuda_utils.h"
-#include "ops.h"
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  // Aphrodite custom ops
-  pybind11::module ops = m.def_submodule("ops", "Aphrodite custom operators");
-
-  // Attention ops
-  ops.def("paged_attention_v1", &paged_attention_v1,
-          "Compute the attention between an input query and the cached "
-          "keys/values using PagedAttention.");
-  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
-
-  // Activation ops
-  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
-  ops.def("gelu_and_mul", &gelu_and_mul,
-          "Activation function used in GeGLU with `none` approximation.");
-  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
-          "Activation function used in GeGLU with `tanh` approximation.");
-  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
-  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
-
-  // Layernorm
-  ops.def("rms_norm", &rms_norm,
-          "Apply Root Mean Square (RMS) Normalization to the input tensor.");
-
-  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
-          "In-place fused Add and RMS Normalization");
-
-  // Rotary embedding
-  ops.def("rotary_embedding", &rotary_embedding,
-          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
-  ops.def("batched_rotary_embedding", &batched_rotary_embedding,
-          "Apply batched GPT-NeoX or GPT-J style rotary embedding to query and "
-          "key");
-
-  ops.def("moe_align_block_size", &moe_align_block_size,
-          "Aligning the number of tokens to be processed by each expert such "
-          "that it is divisible by the block size.");
-
-  // Cache ops
-  pybind11::module cache_ops =
-      m.def_submodule("cache_ops", "Aphrodite cache ops");
-  cache_ops.def("swap_blocks", &swap_blocks,
-                "Swap in (out) the cache blocks from src to dst");
-  cache_ops.def("copy_blocks", &copy_blocks,
-                "Copy the cache blocks from src to dst");
-  cache_ops.def("reshape_and_cache", &reshape_and_cache,
-                "Reshape the key and value tensors and cache them");
-  cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
-                "Reshape the key and value tensors and cache them");
-  cache_ops.def("convert_fp8", &convert_fp8,
-                "Convert the key and value cache to fp8 data type");
-
-  // Cuda utils
-  pybind11::module cuda_utils =
-      m.def_submodule("cuda_utils", "Aphrodite cuda utils");
-  cuda_utils.def("get_device_attribute", &get_device_attribute,
-                 "Gets the specified device attribute.");
-
-  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
-                 &get_max_shared_memory_per_block_device_attribute,
-                 "Gets the maximum shared memory per block device attribute.");
-
-#ifndef USE_ROCM
-  // Custom all-reduce kernels
-  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
-  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
-  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
-  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
-  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
-  custom_ar.def("dispose", &dispose, "dispose");
-  custom_ar.def("meta_size", &meta_size, "meta_size");
-  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
-  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
-                "get_graph_buffer_ipc_meta");
-  custom_ar.def("register_graph_buffers", &register_graph_buffers,
-                "register_graph_buffers");
-#endif
-}

+ 1 - 1
kernels/quantization/aqlm/gemm_kernels.cu

@@ -18,7 +18,7 @@
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAStream.h>
 #include <c10/cuda/CUDAGuard.h>
 

+ 4 - 4
kernels/quantization/autoquant/int4_fp16_gemm_kernels.cu

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <cuda_fp16.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <vector>
@@ -20,9 +20,9 @@ void autoquant_convert_s4_k_m8(
   torch::Tensor _quant_weight_src,
   torch::Tensor _quant_scales,
   torch::Tensor _quant_zeros,
-  int m,
-  int k,
-  int group_size){
+  int64_t m,
+  int64_t k,
+  int64_t group_size){
       auto st_ = _quant_scales.scalar_type();
       const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
       if(st_ == at::ScalarType::Half){

+ 950 - 764
kernels/quantization/awq/gemm_kernels.cu

@@ -1,771 +1,957 @@
 /*
 Adapted from https://github.com/mit-han-lab/llm-awq
 @article{lin2023awq,
-  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
-  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
-  journal={arXiv},
-  year={2023}
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and
+Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
+Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
 }
  */
 
+#include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
 
- #include <torch/extension.h>
- #include <c10/cuda/CUDAGuard.h>
- 
- #include "dequantize.cuh"
- 
- #include <cuda_fp16.h>
- 
- namespace aphrodite {
- namespace awq {
- 
- // Pack two half values.
- static inline __device__ __host__ unsigned
- __pack_half2(const half x, const half y) {
-   unsigned v0 = *((unsigned short *)&x);
-   unsigned v1 = *((unsigned short *)&y);
-   return (v1 << 16) | v0;
- }
- 
- template<int N>
- __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
-   int G,
-   int split_k_iters,
-   half* __restrict__ A,
-   int* __restrict__ B,
-   half* __restrict__ scaling_factors,
-   int* __restrict__ zeros,
-   int M,
-   int IC,
-   int OC,
-   half* __restrict__ C)
- {
-   // Only support matrix n = 64 or 128
-   assert(N == 64 || N == 128);
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
-   assert(false);
- #else
-   static constexpr uint32_t ZERO = 0x0;
-   float C_warp[32];
-   __shared__ half A_shared[16 * (32 + 8)];
-   __shared__ half B_shared[32 * (N + 8)];
- 
-   __shared__ half scaling_factors_shared[N];
-   __shared__ half zeros_shared[N];
- 
-   int j_factors1 = ((OC + N - 1) / N);
-   int blockIdx_x = 0;
-   int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
-   int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
- 
-   half A_shared_warp[8];
-   half B_shared_warp[N / 4];
-   for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
-     for (int i = 0; i < 8; ++i) {
-       C_warp[(j_0_4_init * 8) + i] = 0.0;
-     }
-   }
- 
-   static constexpr int row_stride_warp = 32 * 8 / 32;
-   static constexpr int row_stride = 2 * 32 * 8 / N;
-   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
-   // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-   bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
-   // bool wb_C_flag = (threadIdx.x / 4) < M;
- 
-   half* A_ptr = A
-                 + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
-                 + (((int)threadIdx.x) % (32 / 8)) * 8;
- 
-   int* B_ptr = B
-             + ((int)threadIdx.y) * (OC / 8) * (256 / N)
-             + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
-             + (((int)blockIdx_y) % j_factors1) * (N / 8)
-             + (((int)threadIdx.x) % (N / 8)) * 1;
- // Why * 1 in the above line?
- 
-   half* A_shared_ptr = A_shared
-                     + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
-                     + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
-                     + (((int)threadIdx.x) % (32 / 8) ) * 8;
- 
-   half* B_shared_ptr = B_shared
-                     + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
-                     + (((int)threadIdx.x) / (N / 8)) * (N + 8)
-                     + (((int)threadIdx.x) % (N / 8)) * 8;
- 
-   int* zeros_ptr = zeros
-                 + (((int)blockIdx_y) % j_factors1) * (N / 8)
-                 + ((int)threadIdx.x) % (N / 8);
- 
-   half* scaling_factors_ptr = scaling_factors
-                             + (((int)blockIdx_y) % j_factors1) * N
-                             + (((int)threadIdx.x) % (N / 8)) * 8;
- 
-   half* C_ptr = C
-               + static_cast<long long>(blockIdx_z) * M * OC        // blockIdz.x -> split_k dim
-               + (((int)blockIdx_y) % j_factors1) * N
-               + ((int)threadIdx.y) * (N / 2)
-               + (((int)threadIdx.x) % 4) * 2;
- 
-   // preload s.f. and zeros
-   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-   if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
-   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-     int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
-     __syncthreads();
-     // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-     if (ld_A_flag)
-     {
-       *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
-     }
-     else
-     {
-       *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
-     }
- 
-     // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-     uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
-     uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-     uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
-     /*
-     if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
-       printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
-     }
-     */
-     // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
-     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
- 
-     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
- 
-       // B: 32 x 136 (128+8) float16
-       // each warp: 32 x 4
-       // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
-       // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
-       // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
-       uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
-       uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
-       //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
- 
-       // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
-       // - zero and * scale
-       // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
-       /*
-       if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
-         printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
-       }
-       */
- 
-       // write back
-       *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
-     }
-     __syncthreads();
- 
-     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
-       {
-         unsigned int addr;
-         __asm__ __volatile__(
-           "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-           : "=r"(addr)
-           : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
-         );
- 
- 
-         __asm__ __volatile__(
-           "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
-           "{%0, %1, %2, %3}, [%4];\n"
-           : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
-           : "r"(addr)
-         );
-       }
- 
-       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
-         {
-           unsigned int addr;
-           __asm__ __volatile__(
-             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-             : "=r"(addr)
-             : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
-           );
-           __asm__ __volatile__(
-             "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
-             "{%0, %1, %2, %3}, [%4];\n"
-             : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
-             : "r"(addr)
-           );
-         }
-       }
-       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-         }
- #else
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-         }
- 
- #endif
-       }
-     }
-   }
- 
- // TODO: Shang: Hoist loop invariance.
-   for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
-     for (int local_id = 0; local_id < 8; ++local_id) {
-       int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
-       if (row_offset < M)
-       {
-         *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
-       }
-     }
-   }
- #endif
- }
- 
- __global__ void __launch_bounds__(64) dequantize_weights(
-     int* __restrict__ B,
-     half* __restrict__ scaling_factors,
-     int* __restrict__ zeros,
-     half* __restrict__ C,
-     int G,
-     int in_c,
-     int out_c
- )
- {
-   if (blockIdx.z > 0) {
-       B = B + blockIdx.z * in_c * out_c / 8;
-       scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
-       zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
-       C = C + blockIdx.z * in_c * out_c;
-   }
-   int j_factors1 = 4;
-   int row_stride2 = 4;
-   int split_k_iters = 1;
-   static constexpr uint32_t ZERO = 0x0;
-   half B_shared[32 * (128 + 8)];
- 
-   half* B_shared_ptr2 = B_shared;
- 
-   half B_shared_warp[32];
-   int OC = 512;
- 
-   int N = blockDim.x * gridDim.x;  // 2
-   int col = (blockIdx.x * blockDim.x + threadIdx.x);
-   int row = blockIdx.y * blockDim.y + threadIdx.y;
-   int index1 = 8 * col + 8 * row * N;
-   half* C_ptr2 = C + index1;
- 
-   int index2 = col + row * N;
-   int* B_ptr2 = B + index2;
- 
-   int index3 = col + (int)(row / G) * N;
-   int* zeros_ptr2 = zeros + index3;
-   int index4 = 8 * col + (int)(row / G) * N * 8;
-   half* scaling_factors_ptr2 = scaling_factors + index4;
- 
-   uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
-   uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-   uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
- 
-   uint32_t B_loaded = *(uint32_t*)B_ptr2;
-   uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
-   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
-   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
-   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
-   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
-   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
-   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
-   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
-   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
- 
-   *(uint4*)B_shared_ptr2 = B_loaded_fp16;
- 
-   for (int i = 0; i < 8; ++i) {
-     *(C_ptr2 + i) = B_shared[i];
-   }
- }
- 
- template<int N>
- __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
-   int G,
-   int split_k_iters,
-   half* __restrict__ A,
-   int* __restrict__ B,
-   half* __restrict__ scaling_factors,
-   int* __restrict__ zeros,
-   const float* __restrict__ topk_weights,
-   const int* __restrict__ sorted_token_ids_ptr,
-   const int* __restrict__ expert_ids_ptr,
-   const int* __restrict__ num_tokens_post_padded,
-   const int num_valid_tokens,
-   const int top_k,
-   const int expert_num,
-   int pad_M,
-   int M,
-   int IC,
-   int OC,
-   half* __restrict__ C)
- {
-   // Only support matrix n = 64 or 128
-   assert(N == 64 || N == 128);
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
-   assert(false);
- #else
-   int num_tokens = *num_tokens_post_padded;
-   int j_factors1 = ((OC + N - 1) / N);
-   int blockIdx_x = 0;
-   int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
-   int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
-   int block = blockIdx_y / j_factors1;
-   if (block * 16 >= num_tokens) return;
- 
-   static constexpr uint32_t ZERO = 0x0;
-   float C_warp[32];
-   __shared__ half A_shared[16 * (32 + 8)];
-   __shared__ half B_shared[32 * (N + 8)];
- 
-   __shared__ half scaling_factors_shared[N];
-   __shared__ half zeros_shared[N];
- 
-   half A_shared_warp[8];
-   half B_shared_warp[N / 4];
-   for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
-     for (int i = 0; i < 8; ++i) {
-       C_warp[(j_0_4_init * 8) + i] = 0.0;
-     }
-   }
- 
-   static constexpr int row_stride_warp = 32 * 8 / 32;
-   static constexpr int row_stride = 2 * 32 * 8 / N;
-   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
-   // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
- 
-   int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
-   int token_id = sorted_token_ids_ptr[row];
-   bool ld_A_flag = (token_id < num_valid_tokens);
-   half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;
- 
-   int expert_id = expert_ids_ptr[block];
-   B = B + OC * IC / 8 * expert_id;
-   scaling_factors = scaling_factors + OC * IC / G * expert_id;
-   zeros = zeros + OC * IC / G / 8 * expert_id;
- 
-   int* B_ptr = B
-             + ((int)threadIdx.y) * (OC / 8) * (256 / N)
-             + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
-             + (((int)blockIdx_y) % j_factors1) * (N / 8)
-             + (((int)threadIdx.x) % (N / 8)) * 1;
-   // Why * 1 in the above line?
- 
-   half* A_shared_ptr = A_shared
-                     + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
-                     + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
-                     + (((int)threadIdx.x) % (32 / 8) ) * 8;
- 
-   half* B_shared_ptr = B_shared
-                     + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
-                     + (((int)threadIdx.x) / (N / 8)) * (N + 8)
-                     + (((int)threadIdx.x) % (N / 8)) * 8;
- 
-   int* zeros_ptr = zeros
-                 + (((int)blockIdx_y) % j_factors1) * (N / 8)
-                 + ((int)threadIdx.x) % (N / 8);
- 
-   half* scaling_factors_ptr = scaling_factors
-                             + (((int)blockIdx_y) % j_factors1) * N
-                             + (((int)threadIdx.x) % (N / 8)) * 8;
- 
-   half* C_ptr = C
-               + static_cast<long long>(blockIdx_z) * M * OC * expert_num  // blockIdz.x -> split_k dim
-               + (((int)blockIdx_y) % j_factors1) * N
-               + ((int)threadIdx.y) * (N / 2)
-               + (((int)threadIdx.x) % 4) * 2;
- 
-   // preload s.f. and zeros
-   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-   if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
-   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-     int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
-     __syncthreads();
-     // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-     if (ld_A_flag)
-     {
-       *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
-     }
-     else
-     {
-       *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
-     }
- 
-     uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
-     uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-     uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
- 
-     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
- 
-     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
- 
-       uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
-       uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
- 
-       // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
-       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
-       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
- 
-       // write back
-       *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
-     }
-     __syncthreads();
- 
-     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
-       {
-         unsigned int addr;
-         __asm__ __volatile__(
-           "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-           : "=r"(addr)
-           : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
-         );
- 
- 
-         __asm__ __volatile__(
-           "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
-           "{%0, %1, %2, %3}, [%4];\n"
-           : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
-           : "r"(addr)
-         );
-       }
- 
-       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
-         {
-           unsigned int addr;
-           __asm__ __volatile__(
-             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-             : "=r"(addr)
-             : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
-           );
-           __asm__ __volatile__(
-             "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
-             "{%0, %1, %2, %3}, [%4];\n"
-             : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
-             : "r"(addr)
-           );
-         }
-       }
-       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
-             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-         }
- #else
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-             :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-         }
- 
-         {
-           __asm__ __volatile__(
-             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-             "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-             :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-         }
- 
- #endif
-       }
-     }
-   }
- 
- // TODO: Shang: Hoist loop invariance.
-   for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
-     for (int local_id = 0; local_id < 8; ++local_id) {
-       int row_offset = block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
-       int token_id = sorted_token_ids_ptr[row_offset];
-       if (token_id < num_valid_tokens)
-       {
-         float value = C_warp[(ax1_0_1 * 8) + local_id];
-         if (topk_weights) {
-             value = value * topk_weights[token_id];
-         }
-         *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(value);
-       }
-     }
-   }
- #endif
- }
- 
- } // namespace awq
- } // namespace aphrodite
- 
- torch::Tensor awq_dequantize(
-     torch::Tensor _kernel,
-     torch::Tensor _scaling_factors,
-     torch::Tensor _zeros,
-     int split_k_iters,
-     int thx,
-     int thy)
- {
-     int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1);
-     int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2);
-     int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0);
-     int out_c = qout_c * 8;
-     int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0) : _scaling_factors.size(1));
- 
-     int x_thread = thx;
-     int y_thread = thy;
- 
-     int x_blocks = 1;
-     int y_blocks = 1;
-     if (thx==0) {
-       x_thread = qout_c;
-     }
-     if (thy==0) {
-       y_thread = in_c;
-     }
-     if (thx==0 && thy==0) {
-       x_thread = 8;
-       y_thread = 8;
-       x_blocks = (int)(qout_c / 8);
-       y_blocks = (int)(in_c / 8);
-     }
- 
-     const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
- 
-     auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
-     at::Tensor _de_kernel;
-     if (num_experts == 1) {
-       _de_kernel = torch::empty({in_c, out_c}, options);
-     } else {
-       _de_kernel = torch::empty({num_experts, in_c, out_c}, options);
-     }
- 
-     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
-     auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
-     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
-     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
- 
-     dim3 num_blocks(x_blocks, y_blocks, num_experts);
-     dim3 threads_per_block(x_thread, y_thread);
- 
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     aphrodite::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
-         kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);
- 
-     return _de_kernel;
- }
- 
- // in_feats: M, IC [float16]
- // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
- // scaling_factors: IC // G, OC [float16]
- // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
- // assume that batch_size < 16 for now
- 
- torch::Tensor awq_gemm(
-     torch::Tensor _in_feats,
-     torch::Tensor _kernel,
-     torch::Tensor _scaling_factors,
-     torch::Tensor _zeros,
-     int split_k_iters)
- {
-     int num_in_feats = _in_feats.size(0);
-     int num_in_channels = _in_feats.size(1);
-     const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
- 
-     auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
-     at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
-     int num_out_feats = _out_feats.size(-2);
-     int num_out_channels = _out_feats.size(-1);
- 
-     auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
-     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
-     auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
-     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
-     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
-     int group_size = num_in_channels / _scaling_factors.size(0);
- 
-     if (num_out_channels % 64 != 0)
-         throw std::invalid_argument("OC is not multiple of cta_N = 64");
-     if (num_out_channels % 8 != 0)
-         throw std::invalid_argument("OC is not multiple of pack_num = 8");
-     if (group_size % 32 != 0)
-         throw std::invalid_argument("Group size should be a multiple of 32");
-     if (num_out_channels % group_size != 0)
-         throw std::invalid_argument("OC is not multiple of Group size");
- 
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     if (num_out_channels % 128 == 0)
-     {
-         int j_factors1 = num_out_channels / 128 / 1;
-         dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-         // threadIdx.x: 32
-         // threadIdx.y: i_factors[2] * j_factors[2]
-         dim3 threads_per_block(32, 2);
-         aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
-             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
-             num_out_channels, out_feats);
-     }
-     else if (num_out_channels % 64 == 0)
-     {
-         int j_factors1 = num_out_channels / 64 / 1;
-         dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
- 
-         // threadIdx.x: 32
-         // threadIdx.y: i_factors[2] * j_factors[2]
-         dim3 threads_per_block(32, 2);
-         aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
-             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
-             num_out_channels, out_feats);
-     }
-     return _out_feats.sum(0);
- }
- 
- torch::Tensor awq_group_gemm(
-     torch::Tensor _in_feats,
-     torch::Tensor _kernel,
-     torch::Tensor _scaling_factors,
-     torch::Tensor _zeros,
-     torch::Tensor _topk_weights,
-     torch::Tensor _sorted_token_ids_ptr,
-     torch::Tensor _expert_ids_ptr,
-     torch::Tensor _num_tokens_post_padded,
-     bool mul_weights,
-     int split_k_iters)
- {
-     int num_in_feats = _in_feats.size(0);
-     int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
-     int num_in_channels = _in_feats.size(2);
-     const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
- 
-     auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
-     int num_experts = _topk_weights.size(1);
-     int top_k = num_experts / _in_feats.size(1);
-     int group_size = num_in_channels / _scaling_factors.size(1);
- 
-     at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8}, options);
-     int num_out_channels = _out_feats.size(-1);
- 
-     auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
-     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
-     auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
-     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
-     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
-     auto topk_weights = mul_weights ? reinterpret_cast<float*>(_topk_weights.data_ptr()) : nullptr;
-     auto sorted_token_ids_ptr = reinterpret_cast<int*>(_sorted_token_ids_ptr.data_ptr());
-     auto expert_ids_ptr = reinterpret_cast<int*>(_expert_ids_ptr.data_ptr());
-     auto num_tokens_post_padded = reinterpret_cast<int*>(_num_tokens_post_padded.data_ptr());
- 
-     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-     if (num_out_channels % 128 == 0)
-     {
-         int j_factors1 = num_out_channels / 128 / 1;
-         dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-         // threadIdx.x: 32
-         // threadIdx.y: i_factors[2] * j_factors[2]
-         dim3 threads_per_block(32, 2);
-         aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
-             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
-             topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
-             _topk_weights.numel(), top_k, num_experts, pad_num_in_feats,
-             num_in_feats, num_in_channels, num_out_channels, out_feats);
-     }
-     else if (num_out_channels % 64 == 0)
-     {
-         int j_factors1 = num_out_channels / 64 / 1;
-         dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
- 
-         // threadIdx.x: 32
-         // threadIdx.y: i_factors[2] * j_factors[2]
-         dim3 threads_per_block(32, 2);
-         aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
-             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
-             topk_weights, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded,
-             _topk_weights.numel(), top_k, num_experts, pad_num_in_feats,
-             num_in_feats, num_in_channels, num_out_channels, out_feats);
-     }
-     return _out_feats.sum(0);
- }
+#include "dequantize.cuh"
+
+#include <cuda_fp16.h>
+
+namespace aphrodite {
+namespace awq {
+
+// Pack two half values.
+static inline __device__ __host__ unsigned __pack_half2(const half x,
+                                                        const half y) {
+  unsigned v0 = *((unsigned short*)&x);
+  unsigned v1 = *((unsigned short*)&y);
+  return (v1 << 16) | v0;
+}
+
+template <int N>
+__global__ void __launch_bounds__(64)
+    gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
+                                    half* __restrict__ A, int* __restrict__ B,
+                                    half* __restrict__ scaling_factors,
+                                    int* __restrict__ zeros, int M, int IC,
+                                    int OC, half* __restrict__ C) {
+  // Only support matrix n = 64 or 128
+  assert(N == 64 || N == 128);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+  assert(false);
+#else
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (N + 8)];
+
+  __shared__ half scaling_factors_shared[N];
+  __shared__ half zeros_shared[N];
+
+  int j_factors1 = ((OC + N - 1) / N);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[N / 4];
+  for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / N;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag =
+      (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
+       threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr =
+      A +
+      (((int)blockIdx_y) / j_factors1 * 16 +
+       (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) *
+          IC +
+      (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
+               (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
+               (((int)blockIdx_y) % j_factors1) * (N / 8) +
+               (((int)threadIdx.x) % (N / 8)) * 1;
+  // Why * 1 in the above line?
+
+  half* A_shared_ptr = A_shared +
+                       ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
+                       (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
+                       (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  half* B_shared_ptr = B_shared +
+                       ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
+                       (((int)threadIdx.x) / (N / 8)) * (N + 8) +
+                       (((int)threadIdx.x) % (N / 8)) * 8;
+
+  int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
+                   ((int)threadIdx.x) % (N / 8);
+
+  half* scaling_factors_ptr = scaling_factors +
+                              (((int)blockIdx_y) % j_factors1) * N +
+                              (((int)threadIdx.x) % (N / 8)) * 8;
+
+  half* C_ptr =
+      C +
+      static_cast<long long>(blockIdx_z) * M * OC  // blockIdz.x -> split_k dim
+      + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) +
+      (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag) {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    } else {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale =
+        *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
+    threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
+    B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
+    B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
+      // B: 32 x 136 (128+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
+      // zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
+      // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
+      // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
+      // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
+      // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
+      // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+      uint32_t B_loaded =
+          *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
+      // 8)) * 8);
+
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
+      // % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
+      // q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.x)
+                   : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.x)
+                   : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.y)
+                   : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.y)
+                   : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.z)
+                   : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.z)
+                   : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.w)
+                   : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.w)
+                   : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
+      0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
+      B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
+          B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+            "addr; }\n"
+            : "=r"(addr)
+            : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
+                          (((((int)threadIdx.x) & 15) * 40) +
+                           ((((int)threadIdx.x) >> 4) * 8)))));
+
+        __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
+              "=r"(((unsigned*)(A_shared_warp + 0))[1]),
+              "=r"(((unsigned*)(A_shared_warp + 0))[2]),
+              "=r"(((unsigned*)(A_shared_warp + 0))[3])
+            : "r"(addr));
+      }
+
+      for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+              "addr; }\n"
+              : "=r"(addr)
+              : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
+                                          (((int)threadIdx.y) * (N / 2))) +
+                                         (ax1_0 * 16))])) +
+                            (((((int)threadIdx.x) & 15) * (N + 8)) +
+                             ((((int)threadIdx.x) >> 4) * 8)))));
+          __asm__ __volatile__(
+              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+              "{%0, %1, %2, %3}, [%4];\n"
+              : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
+                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
+                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
+                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
+              : "r"(addr));
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
+  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+  #else
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
+              "%13};\n"
+              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
+              "%13};\n"
+              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+  #endif
+      }
+    }
+  }
+
+  // TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
+                       ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M) {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 +
+          local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+#endif
+}
+
+__global__ void __launch_bounds__(64)
+    dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
+                       int* __restrict__ zeros, half* __restrict__ C, int G,
+                       int in_c, int out_c) {
+  if (blockIdx.z > 0) {
+    B = B + blockIdx.z * in_c * out_c / 8;
+    scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
+    zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
+    C = C + blockIdx.z * in_c * out_c;
+  }
+  int j_factors1 = 4;
+  int row_stride2 = 4;
+  int split_k_iters = 1;
+  static constexpr uint32_t ZERO = 0x0;
+  half B_shared[32 * (128 + 8)];
+
+  half* B_shared_ptr2 = B_shared;
+
+  half B_shared_warp[32];
+  int OC = 512;
+
+  int N = blockDim.x * gridDim.x;  // 2
+  int col = (blockIdx.x * blockDim.x + threadIdx.x);
+  int row = blockIdx.y * blockDim.y + threadIdx.y;
+  int index1 = 8 * col + 8 * row * N;
+  half* C_ptr2 = C + index1;
+
+  int index2 = col + row * N;
+  int* B_ptr2 = B + index2;
+
+  int index3 = col + (int)(row / G) * N;
+  int* zeros_ptr2 = zeros + index3;
+  int index4 = 8 * col + (int)(row / G) * N * 8;
+  half* scaling_factors_ptr2 = scaling_factors + index4;
+
+  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+
+  uint32_t B_loaded = *(uint32_t*)B_ptr2;
+  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+  asm volatile("sub.f16x2 %0, %1, %2;\n"
+               : "=r"(B_loaded_fp16.x)
+               : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(B_loaded_fp16.x)
+               : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n"
+               : "=r"(B_loaded_fp16.y)
+               : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(B_loaded_fp16.y)
+               : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n"
+               : "=r"(B_loaded_fp16.z)
+               : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(B_loaded_fp16.z)
+               : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n"
+               : "=r"(B_loaded_fp16.w)
+               : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+               : "=r"(B_loaded_fp16.w)
+               : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+  *(uint4*)B_shared_ptr2 = B_loaded_fp16;
+
+  for (int i = 0; i < 8; ++i) {
+    *(C_ptr2 + i) = B_shared[i];
+  }
+}
+
+template <int N>
+__global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
+    int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B,
+    half* __restrict__ scaling_factors, int* __restrict__ zeros,
+    const float* __restrict__ topk_weights,
+    const int* __restrict__ sorted_token_ids_ptr,
+    const int* __restrict__ expert_ids_ptr,
+    const int* __restrict__ num_tokens_post_padded, const int num_valid_tokens,
+    const int top_k, const int expert_num, int pad_M, int M, int IC, int OC,
+    half* __restrict__ C) {
+  // Only support matrix n = 64 or 128
+  assert(N == 64 || N == 128);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+  assert(false);
+#else
+  int num_tokens = *num_tokens_post_padded;
+  int j_factors1 = ((OC + N - 1) / N);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((pad_M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((pad_M + 16 - 1) / 16 * j_factors1);
+  int block = blockIdx_y / j_factors1;
+  if (block * 16 >= num_tokens) return;
+
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (N + 8)];
+
+  __shared__ half scaling_factors_shared[N];
+  __shared__ half zeros_shared[N];
+
+  half A_shared_warp[8];
+  half B_shared_warp[N / 4];
+  for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / N;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+
+  int row = (block * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32);
+  int token_id = sorted_token_ids_ptr[row];
+  bool ld_A_flag = (token_id < num_valid_tokens);
+  half* A_ptr = A + token_id / top_k * IC + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  int expert_id = expert_ids_ptr[block];
+  B = B + OC * IC / 8 * expert_id;
+  scaling_factors = scaling_factors + OC * IC / G * expert_id;
+  zeros = zeros + OC * IC / G / 8 * expert_id;
+
+  int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) +
+               (((int)threadIdx.x) / (N / 8)) * (OC / 8) +
+               (((int)blockIdx_y) % j_factors1) * (N / 8) +
+               (((int)threadIdx.x) % (N / 8)) * 1;
+  // Why * 1 in the above line?
+
+  half* A_shared_ptr = A_shared +
+                       ((int)threadIdx.y) * row_stride_warp * (32 + 8) +
+                       (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +
+                       (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  half* B_shared_ptr = B_shared +
+                       ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +
+                       (((int)threadIdx.x) / (N / 8)) * (N + 8) +
+                       (((int)threadIdx.x) % (N / 8)) * 8;
+
+  int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) +
+                   ((int)threadIdx.x) % (N / 8);
+
+  half* scaling_factors_ptr = scaling_factors +
+                              (((int)blockIdx_y) % j_factors1) * N +
+                              (((int)threadIdx.x) % (N / 8)) * 8;
+
+  half* C_ptr = C +
+                static_cast<long long>(blockIdx_z) * M * OC *
+                    expert_num  // blockIdz.x -> split_k dim
+                + (((int)blockIdx_y) % j_factors1) * N +
+                ((int)threadIdx.y) * (N / 2) + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag) {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    } else {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale =
+        *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
+      uint32_t B_loaded =
+          *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+
+      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
+      // q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.x)
+                   : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.x)
+                   : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.y)
+                   : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.y)
+                   : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.z)
+                   : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.z)
+                   : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n"
+                   : "=r"(B_loaded_fp16.w)
+                   : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                   : "=r"(B_loaded_fp16.w)
+                   : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) =
+          B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+            "addr; }\n"
+            : "=r"(addr)
+            : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
+                          (((((int)threadIdx.x) & 15) * 40) +
+                           ((((int)threadIdx.x) >> 4) * 8)))));
+
+        __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
+              "=r"(((unsigned*)(A_shared_warp + 0))[1]),
+              "=r"(((unsigned*)(A_shared_warp + 0))[2]),
+              "=r"(((unsigned*)(A_shared_warp + 0))[3])
+            : "r"(addr));
+      }
+
+      for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+              "addr; }\n"
+              : "=r"(addr)
+              : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
+                                          (((int)threadIdx.y) * (N / 2))) +
+                                         (ax1_0 * 16))])) +
+                            (((((int)threadIdx.x) & 15) * (N + 8)) +
+                             ((((int)threadIdx.x) >> 4) * 8)))));
+          __asm__ __volatile__(
+              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+              "{%0, %1, %2, %3}, [%4];\n"
+              : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
+                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
+                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]),
+                "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3])
+              : "r"(addr));
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
+  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+  #else
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
+              "%13};\n"
+              : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "=f"(((float*)(C_warp + (j_0_4 * 8)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]),
+                "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[2]),
+                "f"(((float*)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
+              "%13};\n"
+              : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])
+              : "r"(((unsigned*)(A_shared_warp + 0))[0]),
+                "r"(((unsigned*)(A_shared_warp + 0))[1]),
+                "r"(((unsigned*)(A_shared_warp + 0))[2]),
+                "r"(((unsigned*)(A_shared_warp + 0))[3]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
+                "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]),
+                "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+  #endif
+      }
+    }
+  }
+
+  // TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < N / 32; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset =
+          block * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      int token_id = sorted_token_ids_ptr[row_offset];
+      if (token_id < num_valid_tokens) {
+        float value = C_warp[(ax1_0_1 * 8) + local_id];
+        if (topk_weights) {
+          value = value * topk_weights[token_id];
+        }
+        *(C_ptr + ax1_0_1 * 16 + token_id * OC + (local_id / 4) * 8 +
+          local_id % 2) = __float2half(value);
+      }
+    }
+  }
+#endif
+}
+
+}  // namespace awq
+}  // namespace aphrodite
+
+torch::Tensor awq_dequantize(torch::Tensor _kernel,
+                             torch::Tensor _scaling_factors,
+                             torch::Tensor _zeros, int64_t split_k_iters,
+                             int64_t thx, int64_t thy) {
+  int in_c = _kernel.dim() == 2 ? _kernel.size(0) : _kernel.size(1);
+  int qout_c = _kernel.dim() == 2 ? _kernel.size(1) : _kernel.size(2);
+  int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0);
+  int out_c = qout_c * 8;
+  int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0)
+                                     : _scaling_factors.size(1));
+
+  int x_thread = thx;
+  int y_thread = thy;
+
+  int x_blocks = 1;
+  int y_blocks = 1;
+  if (thx == 0) {
+    x_thread = qout_c;
+  }
+  if (thy == 0) {
+    y_thread = in_c;
+  }
+  if (thx == 0 && thy == 0) {
+    x_thread = 8;
+    y_thread = 8;
+    x_blocks = (int)(qout_c / 8);
+    y_blocks = (int)(in_c / 8);
+  }
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+  auto options = torch::TensorOptions()
+                     .dtype(_scaling_factors.dtype())
+                     .device(_scaling_factors.device());
+  at::Tensor _de_kernel;
+  if (num_experts == 1) {
+    _de_kernel = torch::empty({in_c, out_c}, options);
+  } else {
+    _de_kernel = torch::empty({num_experts, in_c, out_c}, options);
+  }
+
+  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+  auto scaling_factors =
+      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
+  dim3 num_blocks(x_blocks, y_blocks, num_experts);
+  dim3 threads_per_block(x_thread, y_thread);
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  aphrodite::awq::
+      dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+          kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c);
+
+  return _de_kernel;
+}
+
+// in_feats: M, IC [float16]
+// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
+// scaling_factors: IC // G, OC [float16]
+// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
+// assume that batch_size < 16 for now
+
+torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
+                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
+                       int64_t split_k_iters) {
+  int num_in_feats = _in_feats.size(0);
+  int num_in_channels = _in_feats.size(1);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
+
+  auto options = torch::TensorOptions()
+                     .dtype(_in_feats.dtype())
+                     .device(_in_feats.device());
+  at::Tensor _out_feats =
+      torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
+  int num_out_feats = _out_feats.size(-2);
+  int num_out_channels = _out_feats.size(-1);
+
+  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
+  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+  auto scaling_factors =
+      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+  int group_size = num_in_channels / _scaling_factors.size(0);
+
+  if (num_out_channels % 64 != 0)
+    throw std::invalid_argument("OC is not multiple of cta_N = 64");
+  if (num_out_channels % 8 != 0)
+    throw std::invalid_argument("OC is not multiple of pack_num = 8");
+  if (group_size % 32 != 0)
+    throw std::invalid_argument("Group size should be a multiple of 32");
+  if (num_out_channels % group_size != 0)
+    throw std::invalid_argument("OC is not multiple of Group size");
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (num_out_channels % 128 == 0) {
+    int j_factors1 = num_out_channels / 128 / 1;
+    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+    // threadIdx.x: 32
+    // threadIdx.y: i_factors[2] * j_factors[2]
+    dim3 threads_per_block(32, 2);
+    aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<128>
+        <<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            num_in_feats, num_in_channels, num_out_channels, out_feats);
+  } else if (num_out_channels % 64 == 0) {
+    int j_factors1 = num_out_channels / 64 / 1;
+    dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 *
+                    split_k_iters);
+
+    // threadIdx.x: 32
+    // threadIdx.y: i_factors[2] * j_factors[2]
+    dim3 threads_per_block(32, 2);
+    aphrodite::awq::gemm_forward_4bit_cuda_m16nXk32<64>
+        <<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            num_in_feats, num_in_channels, num_out_channels, out_feats);
+  }
+  return _out_feats.sum(0);
+}
+
+torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
+                             torch::Tensor _scaling_factors,
+                             torch::Tensor _zeros, torch::Tensor _topk_weights,
+                             torch::Tensor _sorted_token_ids_ptr,
+                             torch::Tensor _expert_ids_ptr,
+                             torch::Tensor _num_tokens_post_padded,
+                             bool mul_weights, int split_k_iters) {
+  int num_in_feats = _in_feats.size(0);
+  int pad_num_in_feats = _sorted_token_ids_ptr.size(0);
+  int num_in_channels = _in_feats.size(2);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
+
+  auto options = torch::TensorOptions()
+                     .dtype(_in_feats.dtype())
+                     .device(_in_feats.device());
+  int num_experts = _topk_weights.size(1);
+  int top_k = num_experts / _in_feats.size(1);
+  int group_size = num_in_channels / _scaling_factors.size(1);
+
+  at::Tensor _out_feats = torch::empty(
+      {split_k_iters, num_in_feats, _topk_weights.size(1), _kernel.size(2) * 8},
+      options);
+  int num_out_channels = _out_feats.size(-1);
+
+  auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
+  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+  auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+  auto scaling_factors =
+      reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+  auto topk_weights = mul_weights
+                          ? reinterpret_cast<float*>(_topk_weights.data_ptr())
+                          : nullptr;
+  auto sorted_token_ids_ptr =
+      reinterpret_cast<int*>(_sorted_token_ids_ptr.data_ptr());
+  auto expert_ids_ptr = reinterpret_cast<int*>(_expert_ids_ptr.data_ptr());
+  auto num_tokens_post_padded =
+      reinterpret_cast<int*>(_num_tokens_post_padded.data_ptr());
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (num_out_channels % 128 == 0) {
+    int j_factors1 = num_out_channels / 128 / 1;
+    dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 *
+                    split_k_iters);
+    // threadIdx.x: 32
+    // threadIdx.y: i_factors[2] * j_factors[2]
+    dim3 threads_per_block(32, 2);
+    aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<128>
+        <<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
+            num_tokens_post_padded, _topk_weights.numel(), top_k, num_experts,
+            pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels,
+            out_feats);
+  } else if (num_out_channels % 64 == 0) {
+    int j_factors1 = num_out_channels / 64 / 1;
+    dim3 num_blocks((pad_num_in_feats + 16 - 1) / 16 * j_factors1 *
+                    split_k_iters);
+
+    // threadIdx.x: 32
+    // threadIdx.y: i_factors[2] * j_factors[2]
+    dim3 threads_per_block(32, 2);
+    aphrodite::awq::group_gemm_forward_4bit_cuda_m16nXk32<64>
+        <<<num_blocks, threads_per_block, 0, stream>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
+            topk_weights, sorted_token_ids_ptr, expert_ids_ptr,
+            num_tokens_post_padded, _topk_weights.numel(), top_k, num_experts,
+            pad_num_in_feats, num_in_feats, num_in_channels, num_out_channels,
+            out_feats);
+  }
+  return _out_feats.sum(0);
+}

+ 1 - 1
kernels/quantization/compressed_tensors/int8_quant_kernels.cu

@@ -1,5 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <cmath>
 
 #include "../../dispatch_utils.h"

+ 1 - 1
kernels/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu

@@ -1,5 +1,5 @@
 #include <stddef.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 

+ 1 - 1
kernels/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

@@ -4,7 +4,7 @@
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 

+ 1 - 1
kernels/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu

@@ -1,7 +1,7 @@
 #include <cudaTypedefs.h>
 
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 
 void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,

+ 101 - 180
kernels/quantization/exl2/q_gemm_exl2.cu

@@ -9,8 +9,8 @@
  * copies of the Software, and to permit persons to whom the Software is
  * furnished to do so, subject to the following conditions:
  *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
  *
  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -20,7 +20,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda_runtime.h>
@@ -43,196 +43,117 @@ namespace exl2 {
 #define EXL2_BLOCK_M_SIZE_MAX 8
 #define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
 #if defined(USE_ROCM)
-__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
-                                                               hipblasOperation_t transA,
-                                                               hipblasOperation_t transB,
-                                                               int                m,
-                                                               int                n,
-                                                               int                k,
-                                                               const half*        alpha,
-                                                               const half*        AP,
-                                                               int                lda,
-                                                               const half*        BP,
-                                                               int                ldb,
-                                                               const half*        beta,
-                                                               half*              CP,
-                                                               int                ldc) {
-    return hipblasHgemm(handle, transA, transB, m, n, k,
-                        reinterpret_cast<const hipblasHalf *>(alpha),
-                        reinterpret_cast<const hipblasHalf *>(AP), lda,
-                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
-                        reinterpret_cast<const hipblasHalf *>(beta),
-                        reinterpret_cast<hipblasHalf *>(CP), ldc);
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
+    hipblasHandle_t handle, hipblasOperation_t transA,
+    hipblasOperation_t transB, int m, int n, int k, const half* alpha,
+    const half* AP, int lda, const half* BP, int ldb, const half* beta,
+    half* CP, int ldc) {
+  return hipblasHgemm(handle, transA, transB, m, n, k,
+                      reinterpret_cast<const hipblasHalf*>(alpha),
+                      reinterpret_cast<const hipblasHalf*>(AP), lda,
+                      reinterpret_cast<const hipblasHalf*>(BP), ldb,
+                      reinterpret_cast<const hipblasHalf*>(beta),
+                      reinterpret_cast<hipblasHalf*>(CP), ldc);
 }
-#define hipblasHgemm __compat_hipblasHgemm
+  #define hipblasHgemm __compat_hipblasHgemm
 #endif
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
 
-void gemm_half_q_half_cuda_part
-(
-    const half* a,
-    QMatrix* b,
-    half* c,
-    int size_m,
-    int size_n,
-    int size_k,
-    int m_count,
-    bool clear
-)
-{
-    {
-        dim3 blockDim, gridDim;
-        blockDim.x = EXL2_BLOCK_KN_SIZE;
-        blockDim.y = 1;
-        blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
-        gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(b->height, EXL2_BLOCK_KN_SIZE);
-
-        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count);
-        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-        kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            a,
-            b->cuda_q_weight,
-            b->cuda_q_scale,
-            b->cuda_q_scale_max,
-            c,
-            size_m,
-            size_n,
-            size_k,
-            b->height,
-            b->groups,
-            b->cuda_q_group_map,
-            b->cuda_q_perm,
-            b->rows_8,
-            b->rows_6,
-            b->rows_5,
-            b->rows_4,
-            b->rows_3,
-            b->rows_2,
-            clear
-        );
-    }
-
+void gemm_half_q_half_cuda_part(const half* a, QMatrix* b, half* c, int size_m,
+                                int size_n, int size_k, int m_count,
+                                bool clear) {
+  {
+    dim3 blockDim, gridDim;
+    blockDim.x = EXL2_BLOCK_KN_SIZE;
+    blockDim.y = 1;
+    blockDim.z = 1;
+    gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
+    gridDim.y = DIVIDE(size_m, m_count);
+    gridDim.z = DIVIDE(b->height, EXL2_BLOCK_KN_SIZE);
+
+    fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    kernel<<<gridDim, blockDim, 0, stream>>>(
+        a, b->cuda_q_weight, b->cuda_q_scale, b->cuda_q_scale_max, c, size_m,
+        size_n, size_k, b->height, b->groups, b->cuda_q_group_map,
+        b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3,
+        b->rows_2, clear);
+  }
 }
 
-void gemm_half_q_half_cuda
-(
-    cublasHandle_t cublas_handle,
-    const half* a,
-    QMatrix* b,
-    half* c,
-    int size_m,
-    int size_n,
-    int size_k,
-    bool clear,
-    half* temp_dq
-)
-{
-    if (size_m > MAX_Q_GEMM_ROWS)
-    {
-        // Reconstruct FP16 matrix, then cuBLAS
-        b->reconstruct(temp_dq);
-
-        //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
-
-        const half alpha = __float2half(1.0f);
-        const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
-        cublasHgemm(cublas_handle,
-                    CUBLAS_OP_N,
-                    CUBLAS_OP_N,
-                    size_n, size_m, size_k,
-                    &alpha, temp_dq, size_n,
-                            a,       size_k,
-                    &beta,  c,       size_n);
+void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
+                           QMatrix* b, half* c, int size_m, int size_n,
+                           int size_k, bool clear, half* temp_dq) {
+  if (size_m > MAX_Q_GEMM_ROWS) {
+    // Reconstruct FP16 matrix, then cuBLAS
+    b->reconstruct(temp_dq);
+
+    // cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
+
+    const half alpha = __float2half(1.0f);
+    const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
+    cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k,
+                &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n);
+  } else {
+    // Quantized matmul
+
+    int block_m_size_max = EXL2_BLOCK_M_SIZE_MAX;
+    int max_chunks = size_m / block_m_size_max;
+    int last_chunk = max_chunks * block_m_size_max;
+    int last_chunk_size = size_m - last_chunk;
+
+    if (max_chunks) {
+      gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k,
+                                 block_m_size_max, clear);
     }
-    else
-    {
-        // Quantized matmul
-
-        int block_m_size_max = EXL2_BLOCK_M_SIZE_MAX;
-        int max_chunks = size_m / block_m_size_max;
-        int last_chunk = max_chunks * block_m_size_max;
-        int last_chunk_size = size_m - last_chunk;
-
-        if (max_chunks)
-        {
-            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear);
-        }
-
-        if (last_chunk_size)
-        {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
-        }
+
+    if (last_chunk_size) {
+      gemm_half_q_half_cuda_part(a + last_chunk * size_k, b,
+                                 c + last_chunk * size_n, last_chunk_size,
+                                 size_n, size_k, last_chunk_size, clear);
     }
+  }
 }
 
 }  // namespace exl2
 }  // namespace aphrodite
 
-torch::Tensor exl2_gemm
-(
-    torch::Tensor a,
-    uintptr_t b
-)
-{
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
-    aphrodite::exl2::QMatrix* qm = reinterpret_cast<aphrodite::exl2::QMatrix*> (b);
-
-    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-    at::Tensor c = torch::empty({a.size(0), qm->width}, options);
-    at::Tensor temp_dq;
-    if (c.size(0) > MAX_Q_GEMM_ROWS) {
-      temp_dq = torch::zeros({a.size(1), qm->width}, options);
-    }
-
-    aphrodite::exl2::gemm_half_q_half_cuda
-    (
-        at::cuda::getCurrentCUDABlasHandle(),
-        (const half*) a.data_ptr(),
-        qm,
-        (half*) c.data_ptr(),
-        c.size(0),  // m
-        c.size(1),  // n
-        a.size(1),  // k
-        true,
-        c.size(0) > MAX_Q_GEMM_ROWS? (half*)temp_dq.data_ptr() : NULL
-    );
-    return c;
+torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+  aphrodite::exl2::QMatrix* qm = reinterpret_cast<aphrodite::exl2::QMatrix*>(b);
+
+  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
+  at::Tensor c = torch::empty({a.size(0), qm->width}, options);
+  at::Tensor temp_dq;
+  if (c.size(0) > MAX_Q_GEMM_ROWS) {
+    temp_dq = torch::zeros({a.size(1), qm->width}, options);
+  }
+
+  aphrodite::exl2::gemm_half_q_half_cuda(
+      at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), qm,
+      (half*)c.data_ptr(),
+      c.size(0),  // m
+      c.size(1),  // n
+      a.size(1),  // k
+      true, c.size(0) > MAX_Q_GEMM_ROWS ? (half*)temp_dq.data_ptr() : NULL);
+  return c;
 }
 
-uintptr_t make_q_matrix
-(
-    torch::Tensor q_weight,
-    torch::Tensor q_perm,
-    torch::Tensor q_invperm,
-    torch::Tensor q_scale,
-    torch::Tensor q_scale_max,
-    torch::Tensor q_groups,
-    torch::Tensor q_group_map
-)
-{
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
-    int device = q_weight.device().index();
-    int width = q_weight.size(1);
-    int groups = q_scale.size(0);
-    int height = q_perm.size(0);
-
-    aphrodite::exl2::QMatrix* m = new aphrodite::exl2::QMatrix
-    (
-        device,
-        height,
-        width,
-        groups,
-        (uint32_t*) q_weight.data_ptr(),
-        (uint16_t*) q_perm.data_ptr(),
-        (uint16_t*) q_invperm.data_ptr(),
-        (uint32_t*) q_scale.data_ptr(),
-        (half*) q_scale_max.data_ptr(),
-        (uint16_t*) q_groups.data_ptr(),
-        (uint16_t*) q_group_map.data_ptr()
-    );
-    return reinterpret_cast<uintptr_t>(m);
+uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm,
+                        torch::Tensor q_invperm, torch::Tensor q_scale,
+                        torch::Tensor q_scale_max, torch::Tensor q_groups,
+                        torch::Tensor q_group_map) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
+  int device = q_weight.device().index();
+  int width = q_weight.size(1);
+  int groups = q_scale.size(0);
+  int height = q_perm.size(0);
+
+  aphrodite::exl2::QMatrix* m = new aphrodite::exl2::QMatrix(
+      device, height, width, groups, (uint32_t*)q_weight.data_ptr(),
+      (uint16_t*)q_perm.data_ptr(), (uint16_t*)q_invperm.data_ptr(),
+      (uint32_t*)q_scale.data_ptr(), (half*)q_scale_max.data_ptr(),
+      (uint16_t*)q_groups.data_ptr(), (uint16_t*)q_group_map.data_ptr());
+  return reinterpret_cast<uintptr_t>(m);
 }

+ 311 - 305
kernels/quantization/exl2/q_matrix.cu

@@ -9,8 +9,8 @@
  * copies of the Software, and to permit persons to whom the Software is
  * furnished to do so, subject to the following conditions:
  *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
  *
  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -20,7 +20,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda_runtime.h>
@@ -35,7 +35,6 @@
 #include "quant/qdq_6.cuh"
 #include "quant/qdq_8.cuh"
 
-
 namespace aphrodite {
 namespace exl2 {
 
@@ -48,329 +47,336 @@ namespace exl2 {
 
 // Shuffle quantized data on load
 
-__global__ void shuffle_kernel
-(
-    uint32_t* __restrict__ b_q_weight,
-    const int size_k,
-    const int size_n,
-    const int rows_8,
-    const int rows_6,
-    const int rows_5,
-    const int rows_4,
-    const int rows_3,
-    const int rows_2
-)
-{
-    int n = blockIdx.x * THREADS_X + threadIdx.x;
-    if (n >= size_n) return;
-    int k = 0;
-    uint32_t* b_ptr = b_q_weight + n;
-    while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  4; }
-    while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
-    while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
-    while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  8; }
-    while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
-    while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
+__global__ void shuffle_kernel(uint32_t* __restrict__ b_q_weight,
+                               const int size_k, const int size_n,
+                               const int rows_8, const int rows_6,
+                               const int rows_5, const int rows_4,
+                               const int rows_3, const int rows_2) {
+  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  if (n >= size_n) return;
+  int k = 0;
+  uint32_t* b_ptr = b_q_weight + n;
+  while (k < rows_8) {
+    shuffle_8bit_4(b_ptr, size_n);
+    b_ptr += 1 * size_n;
+    k += 4;
+  }
+  while (k < rows_6) {
+    shuffle_6bit_16(b_ptr, size_n);
+    b_ptr += 3 * size_n;
+    k += 16;
+  }
+  while (k < rows_5) {
+    shuffle_5bit_32(b_ptr, size_n);
+    b_ptr += 5 * size_n;
+    k += 32;
+  }
+  while (k < rows_4) {
+    shuffle_4bit_8(b_ptr, size_n);
+    b_ptr += 1 * size_n;
+    k += 8;
+  }
+  while (k < rows_3) {
+    shuffle_3bit_32(b_ptr, size_n);
+    b_ptr += 3 * size_n;
+    k += 32;
+  }
+  while (k < rows_2) {
+    shuffle_2bit_16(b_ptr, size_n);
+    b_ptr += 1 * size_n;
+    k += 16;
+  }
 }
 
-
 // QMatrix constructor
 
-QMatrix::QMatrix
-(
-    const int _device,
-    const int _height,
-    const int _width,
-    const int _groups,
-
-    uint32_t* _q_weight,
-    uint16_t* _q_perm,
-    uint16_t* _q_invperm,
-    uint32_t* _q_scale,
-    half* _q_scale_max,
-    uint16_t* _q_groups,
-    uint16_t* _q_group_map
-):
-    device(_device),
-    height(_height),
-    width(_width),
-    groups(_groups)
-{
-    cudaSetDevice(device);
-
-    failed = false;
-
-    cuda_q_weight = _q_weight;
-    cuda_q_perm = _q_perm;
-    cuda_q_invperm = _q_invperm;
-    cuda_q_scale = _q_scale;
-    cuda_q_scale_max = _q_scale_max;
-    cuda_q_groups = _q_groups;
-    cuda_q_group_map = _q_group_map;
-
-    // Create group map
-
-    rows_8 = 0;
-    rows_6 = 0;
-    rows_5 = 0;
-    rows_4 = 0;
-    rows_3 = 0;
-    rows_2 = 0;
-
-    {
-        uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
-        cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
-
-        int row = 0;
-        for (int i = 0; i < groups; i++)
-        {
-            int bits = cpu_q_groups[i * 2];
-
-            int rows;
-            if (i < groups - 1)
-            {
-                int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
-                rows = qrows * 32 / bits;
-            }
-            else rows = height - row;
-
-            if (bits == 8) rows_8 += rows;
-            if (bits == 6) rows_6 += rows;
-            if (bits == 5) rows_5 += rows;
-            if (bits == 4) rows_4 += rows;
-            if (bits == 3) rows_3 += rows;
-            if (bits == 2) rows_2 += rows;
-            row += rows;
-        }
-
-        free(cpu_q_groups);
-
-        rows_6 += rows_8;
-        rows_5 += rows_6;
-        rows_4 += rows_5;
-        rows_3 += rows_4;
-        rows_2 += rows_3;
+QMatrix::QMatrix(const int _device, const int _height, const int _width,
+                 const int _groups,
+
+                 uint32_t* _q_weight, uint16_t* _q_perm, uint16_t* _q_invperm,
+                 uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups,
+                 uint16_t* _q_group_map)
+    : device(_device), height(_height), width(_width), groups(_groups) {
+  cudaSetDevice(device);
+
+  failed = false;
+
+  cuda_q_weight = _q_weight;
+  cuda_q_perm = _q_perm;
+  cuda_q_invperm = _q_invperm;
+  cuda_q_scale = _q_scale;
+  cuda_q_scale_max = _q_scale_max;
+  cuda_q_groups = _q_groups;
+  cuda_q_group_map = _q_group_map;
+
+  // Create group map
+
+  rows_8 = 0;
+  rows_6 = 0;
+  rows_5 = 0;
+  rows_4 = 0;
+  rows_3 = 0;
+  rows_2 = 0;
+
+  {
+    uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
+    cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t),
+               cudaMemcpyDeviceToHost);
+
+    int row = 0;
+    for (int i = 0; i < groups; i++) {
+      int bits = cpu_q_groups[i * 2];
+
+      int rows;
+      if (i < groups - 1) {
+        int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
+        rows = qrows * 32 / bits;
+      } else
+        rows = height - row;
+
+      if (bits == 8) rows_8 += rows;
+      if (bits == 6) rows_6 += rows;
+      if (bits == 5) rows_5 += rows;
+      if (bits == 4) rows_4 += rows;
+      if (bits == 3) rows_3 += rows;
+      if (bits == 2) rows_2 += rows;
+      row += rows;
     }
 
-    // Shuffle quantized data
+    free(cpu_q_groups);
 
-    dim3 blockDim, gridDim;
-    blockDim.x = THREADS_X;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(width, THREADS_X);
-    gridDim.y = 1;
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    rows_6 += rows_8;
+    rows_5 += rows_6;
+    rows_4 += rows_5;
+    rows_3 += rows_4;
+    rows_2 += rows_3;
+  }
 
-    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(
-        cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
-}
+  // Shuffle quantized data
+
+  dim3 blockDim, gridDim;
+  blockDim.x = THREADS_X;
+  blockDim.y = 1;
+  gridDim.x = DIVIDE(width, THREADS_X);
+  gridDim.y = 1;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-QMatrix::~QMatrix()
-{
+  shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width,
+                                                   rows_8, rows_6, rows_5,
+                                                   rows_4, rows_3, rows_2);
 }
 
+QMatrix::~QMatrix() {}
 
 // Reconstruct b[k,n]
 
-__global__ void reconstruct_kernel
-(
-    const uint32_t* __restrict__ b_q_weight,
-    const uint16_t* __restrict__ b_q_perm,
-    const uint32_t* __restrict__ b_q_scale,
-    const half* __restrict__ b_q_scale_max,
-    const uint16_t* __restrict__ b_q_group_map,
-    const int size_k,
-    const int size_n,
-    //const int groupsize,
-    const int groups,
-    half* __restrict__ b,
-    const int rows_8,
-    const int rows_6,
-    const int rows_5,
-    const int rows_4,
-    const int rows_3,
-    const int rows_2
-)
-{
-    MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
-
-    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-    int offset_n = BLOCK_KN_SIZE * blockIdx.x;
-
-    // Preload remapping table
-
-    int t = threadIdx.x;
-    __shared__ uint16_t perm[BLOCK_KN_SIZE];
-    if (offset_k + t < size_k)
-        perm[t] = b_q_perm[offset_k + t];
-
-    // Column
-
-    int n = offset_n + t;
-    if (n >= size_n) return;
-
-    // Find initial group
-
-    // int group = offset_k / groupsize;
-    int group = b_q_group_map[offset_k * 2];
-
-    int pre_rows_8 = min(rows_8, offset_k);
-    int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
-    int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
-    int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
-    int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
-    int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
-    int qk = 0;
-    qk += pre_rows_8 / 32 * 8;
-    qk += pre_rows_6 / 32 * 6;
-    qk += pre_rows_5 / 32 * 5;
-    qk += pre_rows_4 / 32 * 4;
-    qk += pre_rows_3 / 32 * 3;
-    qk += pre_rows_2 / 32 * 2;
-
-    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
-
-    half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
-    half2 qs_h2 = __halves2half2(qs_h, qs_h);
-    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
-
-    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
-    int k = offset_k;
-    int lk = 0;
-
-    __syncthreads();
-
-    while (k < rows_8 && k < end_k)
-    {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 4; p++)
-        {
-            half2 dq[4];
-            uint32_t q_0 = *b_ptr; b_ptr += size_n;
-            uint32_t q_1 = *b_ptr; b_ptr += size_n;
-            dequant_8bit_8(q_0, q_1, dq, size_n);
-            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
-            half* dqh = (half*) dq;
-            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
-        }
-        k += 32;
+__global__ void reconstruct_kernel(const uint32_t* __restrict__ b_q_weight,
+                                   const uint16_t* __restrict__ b_q_perm,
+                                   const uint32_t* __restrict__ b_q_scale,
+                                   const half* __restrict__ b_q_scale_max,
+                                   const uint16_t* __restrict__ b_q_group_map,
+                                   const int size_k, const int size_n,
+                                   // const int groupsize,
+                                   const int groups, half* __restrict__ b,
+                                   const int rows_8, const int rows_6,
+                                   const int rows_5, const int rows_4,
+                                   const int rows_3, const int rows_2) {
+  MatrixView_half_rw b_(b, size_k, size_n);
+  MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
+
+  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  int offset_n = BLOCK_KN_SIZE * blockIdx.x;
+
+  // Preload remapping table
+
+  int t = threadIdx.x;
+  __shared__ uint16_t perm[BLOCK_KN_SIZE];
+  if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
+
+  // Column
+
+  int n = offset_n + t;
+  if (n >= size_n) return;
+
+  // Find initial group
+
+  // int group = offset_k / groupsize;
+  int group = b_q_group_map[offset_k * 2];
+
+  int pre_rows_8 = min(rows_8, offset_k);
+  int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
+  int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
+  int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
+  int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
+  int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
+  int qk = 0;
+  qk += pre_rows_8 / 32 * 8;
+  qk += pre_rows_6 / 32 * 6;
+  qk += pre_rows_5 / 32 * 5;
+  qk += pre_rows_4 / 32 * 4;
+  qk += pre_rows_3 / 32 * 3;
+  qk += pre_rows_2 / 32 * 2;
+
+  const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+  half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+  half2 qs_h2 = __halves2half2(qs_h, qs_h);
+  int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
+
+  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+  int k = offset_k;
+  int lk = 0;
+
+  __syncthreads();
+
+  while (k < rows_8 && k < end_k) {
+    if (k == nextgroup) {
+      group++;
+      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+      nextgroup += b_q_group_map[k * 2 + 1];
+      qs_h2 = __halves2half2(qs_h, qs_h);
     }
-
-    while (k < rows_6 && k < end_k)
-    {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 2; p++)
-        {
-            half2 dq[8];
-            uint32_t q_0 = *b_ptr; b_ptr += size_n;
-            uint32_t q_1 = *b_ptr; b_ptr += size_n;
-            uint32_t q_2 = *b_ptr; b_ptr += size_n;
-            dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
-            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
-            half* dqh = (half*) dq;
-            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
-        }
-        k += 32;
+    for (int p = 0; p < 4; p++) {
+      half2 dq[4];
+      uint32_t q_0 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_1 = *b_ptr;
+      b_ptr += size_n;
+      dequant_8bit_8(q_0, q_1, dq, size_n);
+      for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
+      half* dqh = (half*)dq;
+      for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
     }
-
-    while (k < rows_5 && k < end_k)
-    {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 1; p++)
-        {
-            half2 dq[16];
-            uint32_t q_0 = *b_ptr; b_ptr += size_n;
-            uint32_t q_1 = *b_ptr; b_ptr += size_n;
-            uint32_t q_2 = *b_ptr; b_ptr += size_n;
-            uint32_t q_3 = *b_ptr; b_ptr += size_n;
-            uint32_t q_4 = *b_ptr; b_ptr += size_n;
-            dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
-            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
-            half* dqh = (half*) dq;
-            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
-        }
-        k += 32;
+    k += 32;
+  }
+
+  while (k < rows_6 && k < end_k) {
+    if (k == nextgroup) {
+      group++;
+      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+      nextgroup += b_q_group_map[k * 2 + 1];
+      qs_h2 = __halves2half2(qs_h, qs_h);
     }
-
-    while (k < rows_4 && k < end_k)
-    {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 4; p++)
-        {
-            half2 dq[4];
-            uint32_t q_0 = *b_ptr; b_ptr += size_n;
-            dequant_4bit_8(q_0, dq, size_n);
-            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
-            half* dqh = (half*) dq;
-            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
-        }
-        k += 32;
+    for (int p = 0; p < 2; p++) {
+      half2 dq[8];
+      uint32_t q_0 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_1 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_2 = *b_ptr;
+      b_ptr += size_n;
+      dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
+      for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
+      half* dqh = (half*)dq;
+      for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
     }
-
-    while (k < rows_3 && k < end_k)
-    {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 1; p++)
-        {
-            half2 dq[16];
-            uint32_t q_0 = *b_ptr; b_ptr += size_n;
-            uint32_t q_1 = *b_ptr; b_ptr += size_n;
-            uint32_t q_2 = *b_ptr; b_ptr += size_n;
-            dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
-            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
-            half* dqh = (half*) dq;
-            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
-        }
-        k += 32;
+    k += 32;
+  }
+
+  while (k < rows_5 && k < end_k) {
+    if (k == nextgroup) {
+      group++;
+      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+      nextgroup += b_q_group_map[k * 2 + 1];
+      qs_h2 = __halves2half2(qs_h, qs_h);
     }
-
-    while (k < rows_2 && k < end_k)
-    {
-        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
-        for (int p = 0; p < 1; p++)
-        {
-            half2 dq[8];
-            uint32_t q_0 = *b_ptr; b_ptr += size_n;
-            dequant_2bit_16(q_0, dq, size_n);
-            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
-            half* dqh = (half*) dq;
-            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
-        }
-        k += 16;
+    for (int p = 0; p < 1; p++) {
+      half2 dq[16];
+      uint32_t q_0 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_1 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_2 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_3 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_4 = *b_ptr;
+      b_ptr += size_n;
+      dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
+      for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
+      half* dqh = (half*)dq;
+      for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
     }
+    k += 32;
+  }
+
+  while (k < rows_4 && k < end_k) {
+    if (k == nextgroup) {
+      group++;
+      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+      nextgroup += b_q_group_map[k * 2 + 1];
+      qs_h2 = __halves2half2(qs_h, qs_h);
+    }
+    for (int p = 0; p < 4; p++) {
+      half2 dq[4];
+      uint32_t q_0 = *b_ptr;
+      b_ptr += size_n;
+      dequant_4bit_8(q_0, dq, size_n);
+      for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
+      half* dqh = (half*)dq;
+      for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
+    }
+    k += 32;
+  }
+
+  while (k < rows_3 && k < end_k) {
+    if (k == nextgroup) {
+      group++;
+      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+      nextgroup += b_q_group_map[k * 2 + 1];
+      qs_h2 = __halves2half2(qs_h, qs_h);
+    }
+    for (int p = 0; p < 1; p++) {
+      half2 dq[16];
+      uint32_t q_0 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_1 = *b_ptr;
+      b_ptr += size_n;
+      uint32_t q_2 = *b_ptr;
+      b_ptr += size_n;
+      dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
+      for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
+      half* dqh = (half*)dq;
+      for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
+    }
+    k += 32;
+  }
+
+  while (k < rows_2 && k < end_k) {
+    if (k == nextgroup) {
+      group++;
+      qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+      nextgroup += b_q_group_map[k * 2 + 1];
+      qs_h2 = __halves2half2(qs_h, qs_h);
+    }
+    for (int p = 0; p < 1; p++) {
+      half2 dq[8];
+      uint32_t q_0 = *b_ptr;
+      b_ptr += size_n;
+      dequant_2bit_16(q_0, dq, size_n);
+      for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
+      half* dqh = (half*)dq;
+      for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
+    }
+    k += 16;
+  }
 }
 
-void QMatrix::reconstruct(half* out)
-{
-    dim3 blockDim, gridDim;
-    blockDim.x = BLOCK_KN_SIZE;
-    blockDim.y = 1;
-    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
-
-    {
-        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            cuda_q_weight,
-            cuda_q_perm,
-            cuda_q_scale,
-            cuda_q_scale_max,
-            cuda_q_group_map,
-            height,
-            width,
-            //groupsize,
-            groups,
-            out,
-            rows_8,
-            rows_6,
-            rows_5,
-            rows_4,
-            rows_3,
-            rows_2
-        );
-    }
+void QMatrix::reconstruct(half* out) {
+  dim3 blockDim, gridDim;
+  blockDim.x = BLOCK_KN_SIZE;
+  blockDim.y = 1;
+  gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+
+  {
+    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>(
+        cuda_q_weight, cuda_q_perm, cuda_q_scale, cuda_q_scale_max,
+        cuda_q_group_map, height, width,
+        // groupsize,
+        groups, out, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+  }
 }
 
 }  // namespace exl2

+ 1 - 1
kernels/quantization/fp8/common.cu

@@ -1,5 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>

+ 3 - 3
kernels/quantization/gptq/q_gemm.cu

@@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
 #include <cstdint>
 #include <cstdio>
 
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda_runtime.h>
@@ -2238,7 +2238,7 @@ void group_gemm_half_q_half_cuda(const half* a, const uint32_t* b_q_weight,
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
-                        bool use_exllama, int bit) {
+                        bool use_exllama, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
   at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
@@ -2260,7 +2260,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
   return c;
 }
 
-void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) {
+void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
 
   int num_experts = q_weight.dim() == 3 ? q_weight.size(0) : 1;

+ 1 - 1
kernels/quantization/gptq_marlin/gptq_marlin.cuh

@@ -1,6 +1,6 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>

+ 1 - 1
kernels/quantization/marlin/dense/marlin_cuda_kernel.cu

@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
- #include <torch/extension.h>
+ #include <torch/all.h>
 
  #include <ATen/cuda/CUDAContext.h>
  #include <c10/cuda/CUDAGuard.h>

+ 1 - 1
kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu

@@ -16,7 +16,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <torch/extension.h>
+#include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>

+ 0 - 72
kernels/quantization/quant_ops.cpp

@@ -1,72 +0,0 @@
-#include "quant_ops.h"
-#include <torch/extension.h>
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  // Aphrodite quantization ops
-  pybind11::module quant_ops =
-      m.def_submodule("quant_ops", "Aphrodite custom quant operators");
-
-#ifndef USE_ROCM
-  // AQLM
-  quant_ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
-  quant_ops.def("aqlm_dequant", &aqlm_dequant, "Dequantization for AQLM");
-  // AWQ
-  quant_ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
-  quant_ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
-  quant_ops.def("awq_group_gemm", &awq_group_gemm,
-                "Grouped Quantized GEMM for AWQ");
-  // GGUF
-  quant_ops.def("ggml_dequantize", &ggml_dequantize, "ggml_dequantize");
-  quant_ops.def("ggml_mul_mat_vec", &ggml_mul_mat_vec, "ggml_mul_mat_vec");
-  quant_ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8,
-                "ggml_mul_mat_vec_a8");
-  quant_ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8, "ggml_mul_mat_a8");
-  // Marlin
-  quant_ops.def("marlin_gemm", &marlin_gemm,
-                "Marlin Optimized Quantized GEMM for GPTQ");
-  quant_ops.def("marlin_gemm", &marlin_gemm,
-                "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
-  quant_ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
-                "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
-  quant_ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
-                "gptq_marlin Optimized Quantized GEMM for GPTQ");
-  quant_ops.def("gptq_marlin_repack", &gptq_marlin_repack,
-                "gptq_marlin repack from GPTQ");
-  // SmoothQuant+
-  quant_ops.def("autoquant_convert_s4_k_m8", &autoquant_convert_s4_k_m8,
-                "convert kernel.");
-  quant_ops.def("autoquant_s4_f16_gemm", &autoquant_s4_f16_gemm,
-                "weight int4 activation float16 gemm kernel.");
-  // QuIP#
-  quant_ops.def("quip_decompress", &decompress_e8p_origorder,
-                "decompress_packed_e8p");
-  quant_ops.def("quip_gemv", &e8p_mm_origorder, "e8p_mm_origorder");
-  // CUTLASS w8a8
-  quant_ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
-                "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
-                "per-row/column quantization.");
-#endif
-  // GPTQ
-  quant_ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
-  quant_ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
-  quant_ops.def("group_gptq_gemm", &group_gptq_gemm,
-                "Grouped Quantized GEMM for GPTQ");
-  quant_ops.def("dequant_gptq", &dequant_gptq,
-                "Dequantize gptq weight to half");
-  // SqueezeLLM
-  quant_ops.def("squeezellm_gemm", &squeezellm_gemm,
-                "Quantized GEMM for SqueezeLLM");
-  // INT8
-  quant_ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
-                "Compute int8 quantized tensor for given scaling factor");
-  quant_ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
-                "Compute int8 quantized tensor and scaling factor");
-  // ExLlamaV2
-  quant_ops.def("exl2_make_q_matrix", &make_q_matrix, "preprocess for exl2");
-  quant_ops.def("exl2_gemm", &exl2_gemm, "exl2 gemm");
-  // FP8
-  quant_ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
-                "Compute FP8 quantized tensor for given scaling factor");
-  quant_ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
-                "Compute FP8 quantized tensor and scaling factor");
-}

+ 11 - 50
kernels/quantization/quant_ops.h

@@ -1,6 +1,6 @@
 #pragma once
 
-#include <torch/extension.h>
+#include <torch/library.h>
 
 #ifndef USE_ROCM
 // AQLM
@@ -17,12 +17,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
 // AWQ
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
-                       int split_k_iters);
+                       int64_t split_k_iters);
 
 torch::Tensor awq_dequantize(torch::Tensor _kernel,
                              torch::Tensor _scaling_factors,
-                             torch::Tensor _zeros, int split_k_iters, int thx,
-                             int thy);
+                             torch::Tensor _zeros, int64_t split_k_iters,
+                             int64_t thx, int64_t thy);
 
 torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                              torch::Tensor _scaling_factors,
@@ -30,42 +30,16 @@ torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                              torch::Tensor _sorted_token_ids_ptr,
                              torch::Tensor _expert_ids_ptr,
                              torch::Tensor _num_tokens_post_padded,
-                             bool mul_weights, int split_k_iters);
-#endif
-
-// ExLlamav2
-torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b);
-
-uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm,
-                        torch::Tensor q_invperm, torch::Tensor q_scale,
-                        torch::Tensor q_scale_max, torch::Tensor q_groups,
-                        torch::Tensor q_group_map);
-
-#ifndef USE_ROCM
-// GGUF
-torch::Tensor ggml_dequantize(torch::Tensor X, int8_t type, int64_t m,
-                              int64_t n);
-
-torch::Tensor ggml_mul_mat_vec(torch::Tensor W,  // quant weight
-                               torch::Tensor X,  // input
-                               int8_t type, int64_t m);
-
-torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
-                                  torch::Tensor X,  // input
-                                  int8_t type, int64_t row);
-
-torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
-                              torch::Tensor X,  // input
-                              int8_t type, int64_t row);
+                             bool mul_weights, int64_t split_k_iters);
 #endif
 
 // GPTQ
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
-                        bool use_exllama, int bit);
+                        bool use_exllama, int64_t bit);
 
-void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
+void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
 torch::Tensor group_gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                               torch::Tensor b_gptq_qzeros,
@@ -79,7 +53,7 @@ torch::Tensor group_gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
 torch::Tensor dequant_gptq(torch::Tensor b_q_weight,
                            torch::Tensor b_gptq_qzeros,
                            torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
-                           int bits, bool use_exllama);
+                           int64_t bits, bool use_exllama);
 
 #ifndef USE_ROCM
 // Marlin
@@ -111,22 +85,9 @@ at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
 void decompress_e8p_origorder(torch::Tensor YIs, torch::Tensor CB,
                               torch::Tensor& Y);
 
-// SmoothQuant+
-torch::Tensor autoquant_s4_f16_gemm(torch::Tensor _in_feats,
-                                    torch::Tensor _kernel,
-                                    torch::Tensor _scales_zeros);
-
-void autoquant_convert_s4_k_m8(torch::Tensor _weight_dest,
-                               torch::Tensor _quant_scales_zeros_dest,
-                               torch::Tensor _workspace,
-                               torch::Tensor _quant_weight_src,
-                               torch::Tensor _quant_scales,
-                               torch::Tensor _quant_zeros, int m, int k,
-                               int group_size);
-
-int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
-                         torch::Tensor const& b, torch::Tensor const& a_scales,
-                         torch::Tensor const& b_scales);
+void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
+                          torch::Tensor const& b, torch::Tensor const& a_scales,
+                          torch::Tensor const& b_scales);
 
 #endif
 

+ 231 - 340
kernels/quantization/quip/origin_order.cu

@@ -9,10 +9,9 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/DeviceGuard.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
-
 template <typename U, typename V>
 constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
   static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
@@ -45,144 +44,118 @@ constexpr int32_t kMTileSize = 16;
 constexpr int32_t kNTileSize = 8;
 constexpr int32_t kKTileSize = 16;
 
-struct __align__(16) f16x2x4_u32 {
-  uint32_t vals[4];
-};
-struct __align__(16) f16x2x2_u32 {
-  uint32_t vals[2];
-};
+struct __align__(16) f16x2x4_u32 { uint32_t vals[4]; };
+struct __align__(16) f16x2x2_u32 { uint32_t vals[2]; };
 
 struct ALayout_RM {
-template <int KTilesToLoad>
-static __device__ void load(
-    const half* A,
-    int32_t m,
-    int32_t k,
-    int32_t mTiles,
-    int32_t mTile,
-    int32_t kTiles,
-    int32_t kTileStart,
-    int32_t laneId,
-    f16x2x4_u32 out[KTilesToLoad]) {
-  const auto mLane = mTile * kMTileSize + (laneId / 4);
-  const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 4;
-
-  // access
-  // [mTile * kMTileSize + (laneId / 4)]
-  // [kTileStart * kKTileSize + (laneId % 4) * 2]
-  auto aPtr = A + mLane * k + kLane;
-
-  auto aPtrPlus8Rows = aPtr + 8 * k;
-
-  bool m0InBounds = mLane < m;
-  bool m1InBounds = (mLane + 8) < m;
+  template <int KTilesToLoad>
+  static __device__ void load(const half* A, int32_t m, int32_t k,
+                              int32_t mTiles, int32_t mTile, int32_t kTiles,
+                              int32_t kTileStart, int32_t laneId,
+                              f16x2x4_u32 out[KTilesToLoad]) {
+    const auto mLane = mTile * kMTileSize + (laneId / 4);
+    const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 4;
+
+    // access
+    // [mTile * kMTileSize + (laneId / 4)]
+    // [kTileStart * kKTileSize + (laneId % 4) * 2]
+    auto aPtr = A + mLane * k + kLane;
+
+    auto aPtrPlus8Rows = aPtr + 8 * k;
+
+    bool m0InBounds = mLane < m;
+    bool m1InBounds = (mLane + 8) < m;
 
 #pragma unroll
-  for (int i = 0; i < KTilesToLoad; ++i) {
-    out[i].vals[0] = m0InBounds
-          ? *reinterpret_cast<const uint32_t*>(aPtr  + i * kKTileSize)
-          : uint32_t(0);
-    out[i].vals[1] = m1InBounds
-          ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows  + i * kKTileSize)
-          : uint32_t(0);
-
-    out[i].vals[2] = m0InBounds
-          ? *reinterpret_cast<const uint32_t*>(aPtr  + i * kKTileSize + 2)
-          : uint32_t(0);
-    out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
-                                        aPtrPlus8Rows  + i * kKTileSize + 2)
+    for (int i = 0; i < KTilesToLoad; ++i) {
+      out[i].vals[0] =
+          m0InBounds ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
+                     : uint32_t(0);
+      out[i].vals[1] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
+                                        aPtrPlus8Rows + i * kKTileSize)
                                   : uint32_t(0);
+
+      out[i].vals[2] =
+          m0InBounds
+              ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 2)
+              : uint32_t(0);
+      out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
+                                        aPtrPlus8Rows + i * kKTileSize + 2)
+                                  : uint32_t(0);
+    }
   }
-}
 
-static __device__ void store(
-    half* C,
-    int32_t m,
-    int32_t n,
-    int32_t mOutTiles,
-    int32_t mTile,
-    int32_t nOutTiles,
-    int32_t nTile,
-    int32_t laneId,
-    const float4& out) {
+  static __device__ void store(half* C, int32_t m, int32_t n, int32_t mOutTiles,
+                               int32_t mTile, int32_t nOutTiles, int32_t nTile,
+                               int32_t laneId, const float4& out) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
 
-  // sum.x / sum.y are written at
-  // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
-  // sum.z / sum.w are written at
-  // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
-  // i.e., same columns, different row.
-  const int outRow = mTile * kMTileSize + (laneId / 4);
-  const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
+    // sum.x / sum.y are written at
+    // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
+    // sum.z / sum.w are written at
+    // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
+    // i.e., same columns, different row.
+    const int outRow = mTile * kMTileSize + (laneId / 4);
+    const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
 
-  // Pointer where sum.x / sum.y is written
-  auto cPtr = C + outRow * n + outCol;
+    // Pointer where sum.x / sum.y is written
+    auto cPtr = C + outRow * n + outCol;
 
-  auto v01 = __float22half2_rn(float2{out.x, out.y});
-  auto v23 = __float22half2_rn(float2{out.z, out.w});
+    auto v01 = __float22half2_rn(float2{out.x, out.y});
+    auto v23 = __float22half2_rn(float2{out.z, out.w});
 
-  if (outRow < m) {
-    *reinterpret_cast<half2*>(cPtr) = v01;
-  }
+    if (outRow < m) {
+      *reinterpret_cast<half2*>(cPtr) = v01;
+    }
 
-  // sum.z, sum.w at +8 rows from cPtr
-  if (outRow + 8 < m) {
-    *reinterpret_cast<half2*>(cPtr + 8 * n) = v23;
-  }
+    // sum.z, sum.w at +8 rows from cPtr
+    if (outRow + 8 < m) {
+      *reinterpret_cast<half2*>(cPtr + 8 * n) = v23;
+    }
 #endif
-}
+  }
 };
 
 struct BLayout_D4 {
-static constexpr bool use_codebook = true;
-
-template <int KTilesPerIteration>
-static __device__ void load(
-    const void* __restrict__ B,
-    const uint64_t* __restrict__ CB,
-    int32_t n,
-    int32_t k,
-    int32_t nTiles,
-    int32_t nTile,
-    int32_t kTiles,
-    int32_t kTileStart,
-    int32_t laneId,
-    f16x2x2_u32 b[KTilesPerIteration]) {
+  static constexpr bool use_codebook = true;
+
+  template <int KTilesPerIteration>
+  static __device__ void load(const void* __restrict__ B,
+                              const uint64_t* __restrict__ CB, int32_t n,
+                              int32_t k, int32_t nTiles, int32_t nTile,
+                              int32_t kTiles, int32_t kTileStart,
+                              int32_t laneId,
+                              f16x2x2_u32 b[KTilesPerIteration]) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  auto Bptr = reinterpret_cast<const uint8_t*>(B);
+    auto Bptr = reinterpret_cast<const uint8_t*>(B);
   #pragma unroll
-  for (int i = 0; i < KTilesPerIteration; ++i) {
-       const int row = nTile * kNTileSize + laneId / 4;
-       const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
-       *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[row * k/4 + col]];
-  }
+    for (int i = 0; i < KTilesPerIteration; ++i) {
+      const int row = nTile * kNTileSize + laneId / 4;
+      const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
+      *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[row * k / 4 + col]];
+    }
 #endif
-}
+  }
 };
 
 struct BLayout_HI {
-static constexpr bool use_codebook = false;
-
-template <int KTilesPerIteration>
-static __device__ void load(
-    const void* __restrict__ B,
-    const uint64_t* __restrict__ CB,
-    int32_t n,
-    int32_t k,
-    int32_t nTiles,
-    int32_t nTile,
-    int32_t kTiles,
-    int32_t kTileStart,
-    int32_t laneId,
-    f16x2x2_u32 b[KTilesPerIteration]) {
+  static constexpr bool use_codebook = false;
+
+  template <int KTilesPerIteration>
+  static __device__ void load(const void* __restrict__ B,
+                              const uint64_t* __restrict__ CB, int32_t n,
+                              int32_t k, int32_t nTiles, int32_t nTile,
+                              int32_t kTiles, int32_t kTileStart,
+                              int32_t laneId,
+                              f16x2x2_u32 b[KTilesPerIteration]) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  auto Bptr = reinterpret_cast<const uint32_t*>(B);
+    auto Bptr = reinterpret_cast<const uint32_t*>(B);
   #pragma unroll
-  for (int i = 0; i < KTilesPerIteration; ++i) {
+    for (int i = 0; i < KTilesPerIteration; ++i) {
       const int row = nTile * kNTileSize + laneId / 4;
       const int col = (kTileStart + i) * kKTileSize / 8 + (laneId % 4) / 2;
       // simply use code - 7.5 instead of reading codebook
-      uint32_t code = Bptr[row * k/8 + col];
+      uint32_t code = Bptr[row * k / 8 + col];
 
       const uint32_t c0 = 0x64086408;
       const half y16_ = __float2half_rn(1.0f / 16.0f);
@@ -191,22 +164,20 @@ static __device__ void load(
       const half2 z16 = __halves2half2(z16_, z16_);
 
       uint32_t qa = code >> ((laneId & 1) * 8);
-      uint32_t q0 = (((qa & 0x000f000f) << 4)| c0);
+      uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
       uint32_t q1 = ((qa & 0x00f000f0) | c0);
       *(half2*)(b[i].vals) = __hfma2(*((half2*)(&q0)), y16, z16);
-      *(half2*)(b[i].vals+1) = __hfma2(*((half2*)(&q1)), y16, z16);
-  }
+      *(half2*)(b[i].vals + 1) = __hfma2(*((half2*)(&q1)), y16, z16);
+    }
 #endif
-}
+  }
 };
 
 struct BLayout_E8 {
-static constexpr bool use_codebook = true;
+  static constexpr bool use_codebook = true;
 
-__device__ static inline uint64_t decode8weights(
-    uint16_t weight_compressed,
-    const int64_t *__restrict__ codebook_abs
-) {
+  __device__ static inline uint64_t decode8weights(
+      uint16_t weight_compressed, const int64_t* __restrict__ codebook_abs) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
 
     uint8_t bits_sign = weight_compressed & 0xff;
@@ -225,18 +196,17 @@ __device__ static inline uint64_t decode8weights(
 
     return packed;
 #endif
-}
+  }
 
-__device__ static inline uint32_t decode8weights(
-    uint16_t weight_compressed,
-    const int64_t *__restrict__ codebook_abs,
-    int idx
-) {
+  __device__ static inline uint32_t decode8weights(
+      uint16_t weight_compressed, const int64_t* __restrict__ codebook_abs,
+      int idx) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-    uint8_t bits_sign = weight_compressed & 0xff; //__brev(weight_compressed) >> 24;
+    uint8_t bits_sign =
+        weight_compressed & 0xff;  //__brev(weight_compressed) >> 24;
     const uint32_t magic_nums[2] = {0x08040201ll, 0x80402010ll};
     uint8_t parity = __popc(bits_sign) & 1;
-    uint8_t sign_vec = bits_sign ^ parity; // (parity << 7);
+    uint8_t sign_vec = bits_sign ^ parity;  // (parity << 7);
     uint16_t bits_abs = (weight_compressed >> 8);
     uint32_t packed = ((uint32_t*)codebook_abs)[(bits_abs << 1) + idx];
     uint32_t magic_num = magic_nums[idx];
@@ -249,75 +219,61 @@ __device__ static inline uint32_t decode8weights(
     packed -= parity * 0x02020202;
     return packed;
 #endif
-};
-
-template <int KTilesPerIteration>
-static __device__ void load(
-    const void* __restrict__ B,
-    const uint64_t* __restrict__ CB,
-    int32_t n,
-    int32_t k,
-    int32_t nTiles,
-    int32_t nTile,
-    int32_t kTiles,
-    int32_t kTileStart,
-    int32_t laneId,
-    f16x2x2_u32 b[KTilesPerIteration]) {
+  };
+
+  template <int KTilesPerIteration>
+  static __device__ void load(const void* __restrict__ B,
+                              const uint64_t* __restrict__ CB, int32_t n,
+                              int32_t k, int32_t nTiles, int32_t nTile,
+                              int32_t kTiles, int32_t kTileStart,
+                              int32_t laneId,
+                              f16x2x2_u32 b[KTilesPerIteration]) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
-  auto Bptr = (const uint16_t*) B;
+    auto Bptr = (const uint16_t*)B;
   #pragma unroll
-  for (int i = 0; i < KTilesPerIteration; ++i) {
-       const int row = nTile * kNTileSize + laneId / 4;
-       const int col = (kTileStart + i) * kKTileSize / 8 + laneId % 4 / 2;
-       uint32_t decoded = decode8weights(Bptr[row * k/8 + col], (const int64_t*)CB, laneId & 1);
-       half2 unpacked[2];
-       uint32_t lower_half = decoded & 0x00ff00ff;
-       lower_half = (lower_half ^ 0x5c805c80);
-       memcpy(unpacked, &lower_half, sizeof(uint32_t));
-       uint32_t upper_half = (decoded & 0xff00ff00) >> 8;
-       upper_half = (upper_half ^ 0x5c805c80);
-       memcpy(unpacked + 1, &upper_half, sizeof(uint32_t));
-
-       const half adjust_ = __float2half_rn(-288.0f);
-       const half2 adjust = __halves2half2(adjust_, adjust_);
-       unpacked[0] = __hadd2(unpacked[0], adjust);
-       unpacked[1] = __hadd2(unpacked[1], adjust);
-       *(reinterpret_cast<uint64_t*>(b[i].vals)) = *(reinterpret_cast<uint64_t*>(unpacked));
-       //*((half*)(b[i].vals)) = unpacked[0];
-       //*((half*)(b[i].vals) + 1) = unpacked[0].y;
-       //*((half*)(b[i].vals) + 2) = unpacked[1].x;
-       //*((half*)(b[i].vals) + 3) = unpacked[1].y;
-  }
+    for (int i = 0; i < KTilesPerIteration; ++i) {
+      const int row = nTile * kNTileSize + laneId / 4;
+      const int col = (kTileStart + i) * kKTileSize / 8 + laneId % 4 / 2;
+      uint32_t decoded = decode8weights(Bptr[row * k / 8 + col],
+                                        (const int64_t*)CB, laneId & 1);
+      half2 unpacked[2];
+      uint32_t lower_half = decoded & 0x00ff00ff;
+      lower_half = (lower_half ^ 0x5c805c80);
+      memcpy(unpacked, &lower_half, sizeof(uint32_t));
+      uint32_t upper_half = (decoded & 0xff00ff00) >> 8;
+      upper_half = (upper_half ^ 0x5c805c80);
+      memcpy(unpacked + 1, &upper_half, sizeof(uint32_t));
+
+      const half adjust_ = __float2half_rn(-288.0f);
+      const half2 adjust = __halves2half2(adjust_, adjust_);
+      unpacked[0] = __hadd2(unpacked[0], adjust);
+      unpacked[1] = __hadd2(unpacked[1], adjust);
+      *(reinterpret_cast<uint64_t*>(b[i].vals)) =
+          *(reinterpret_cast<uint64_t*>(unpacked));
+      //*((half*)(b[i].vals)) = unpacked[0];
+      //*((half*)(b[i].vals) + 1) = unpacked[0].y;
+      //*((half*)(b[i].vals) + 2) = unpacked[1].x;
+      //*((half*)(b[i].vals) + 3) = unpacked[1].y;
+    }
 #endif
-}
+  }
 };
 
-
-template <
-    typename ALayout,
-    typename BLayout,
-    typename CLayout,
-    int Warps,
-    int KTilesPerIteration>
-__global__
-__launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
+template <typename ALayout, typename BLayout, typename CLayout, int Warps,
+          int KTilesPerIteration>
+__global__ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
     // Data for the A matrix, loaded as per ALayout
-    const half* __restrict__ A,
-    const void* __restrict__ B,
+    const half* __restrict__ A, const void* __restrict__ B,
     const uint64_t* __restrict__ CB,
 
     // Output data for the C matrix, stored as per CLayout
     half* __restrict__ C,
 
     // The size of the matrix multiplication
-    int32_t m,
-    int32_t n,
-    int32_t k,
+    int32_t m, int32_t n, int32_t k,
 
     // The size of the matrix multiplication, in multiples of our TC tile size
-    int32_t mTiles,
-    int32_t nTiles,
-    int32_t kTiles) {
+    int32_t mTiles, int32_t nTiles, int32_t kTiles) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   __shared__ uint64_t CB_[256];
   if (BLayout::use_codebook) {
@@ -333,57 +289,49 @@ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
 
   float4 c{0.0f, 0.0f, 0.0f, 0.0f};
 
- // First, handle whole multiples of KTilesPerIteration
+  // First, handle whole multiples of KTilesPerIteration
   auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
 
   // Each warp handles a set of KTilesPerIteration under the above limit
-  for (int32_t kTileBase = warpId * KTilesPerIteration; kTileBase < kTilesLimit; kTileBase += Warps * KTilesPerIteration) {
+  for (int32_t kTileBase = warpId * KTilesPerIteration; kTileBase < kTilesLimit;
+       kTileBase += Warps * KTilesPerIteration) {
     //
     // Load data from A
     //
     f16x2x4_u32 a[KTilesPerIteration];
-    ALayout::template load<KTilesPerIteration>(
-        A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
+    ALayout::template load<KTilesPerIteration>(A, m, k, mTiles, mTile, kTiles,
+                                               kTileBase, laneId, a);
 
     //
     // Load data from B and de-quantize as needed
     //
     f16x2x2_u32 b[KTilesPerIteration];
-    BLayout::template load<KTilesPerIteration>(
-        B, CB_, n, k, nTiles, nTile, kTiles, kTileBase, laneId, b);
+    BLayout::template load<KTilesPerIteration>(B, CB_, n, k, nTiles, nTile,
+                                               kTiles, kTileBase, laneId, b);
 
-    // Now, perform the matrix multiplication
-    //
-    #pragma unroll
+  // Now, perform the matrix multiplication
+  //
+  #pragma unroll
     for (int i = 0; i < KTilesPerIteration / 2; ++i) {
       float4 cTmp[2];
 
-      #pragma unroll
+  #pragma unroll
       for (int k = 0; k < 2; ++k) {
         cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
       }
 
-      #pragma unroll
+  #pragma unroll
       for (int k = 0; k < 2; ++k) {
         asm volatile(
-              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-              "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
-              : "=f"(cTmp[k].x),
-                "=f"(cTmp[k].y),
-                "=f"(cTmp[k].z),
-                "=f"(cTmp[k].w)
-              : "r"(a[i * 2 + k].vals[0]),
-                "r"(a[i * 2 + k].vals[1]),
-                "r"(a[i * 2 + k].vals[2]),
-                "r"(a[i * 2 + k].vals[3]),
-                "r"(b[i * 2 + k].vals[0]),
-                "r"(b[i * 2 + k].vals[1]),
-                "f"(cTmp[k].x),
-                "f"(cTmp[k].y),
-                "f"(cTmp[k].z),
-                "f"(cTmp[k].w));
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
+            : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
+            : "r"(a[i * 2 + k].vals[0]), "r"(a[i * 2 + k].vals[1]),
+              "r"(a[i * 2 + k].vals[2]), "r"(a[i * 2 + k].vals[3]),
+              "r"(b[i * 2 + k].vals[0]), "r"(b[i * 2 + k].vals[1]),
+              "f"(cTmp[k].x), "f"(cTmp[k].y), "f"(cTmp[k].z), "f"(cTmp[k].w));
       }
-      #pragma unroll
+  #pragma unroll
       for (int k = 0; k < 2; ++k) {
         c.x += cTmp[k].x;
         c.y += cTmp[k].y;
@@ -392,8 +340,7 @@ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
       }
     }
 
-  } // for all tiles under kTilesLimit
-
+  }  // for all tiles under kTilesLimit
 
   auto kTileBaseRemaining = kTilesLimit + warpId;
 
@@ -401,30 +348,20 @@ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
   // kInnerKTiles k-tiles at a time
   if (kTileBaseRemaining < kTiles) {
     f16x2x4_u32 a;
-    ALayout::template load<1>(
-        A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, &a);
+    ALayout::template load<1>(A, m, k, mTiles, mTile, kTiles,
+                              kTileBaseRemaining, laneId, &a);
 
     f16x2x2_u32 b;
-    BLayout::template load<1>(
-        B, CB, n, k, nTiles, nTile, kTiles, kTileBaseRemaining, laneId, &b);
+    BLayout::template load<1>(B, CB, n, k, nTiles, nTile, kTiles,
+                              kTileBaseRemaining, laneId, &b);
 
     asm volatile(
-              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
-              "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
-              : "=f"(c.x),
-                "=f"(c.y),
-                "=f"(c.z),
-                "=f"(c.w)
-              : "r"(a.vals[0]),
-                "r"(a.vals[1]),
-                "r"(a.vals[2]),
-                "r"(a.vals[3]),
-                "r"(b.vals[0]),
-                "r"(b.vals[1]),
-                "f"(c.x),
-                "f"(c.y),
-                "f"(c.z),
-                "f"(c.w));
+        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
+        : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
+        : "r"(a.vals[0]), "r"(a.vals[1]), "r"(a.vals[2]), "r"(a.vals[3]),
+          "r"(b.vals[0]), "r"(b.vals[1]), "f"(c.x), "f"(c.y), "f"(c.z),
+          "f"(c.w));
   }
   // Reduce independent k-tiles (same m/n) across warps
   __shared__ float4 smem_sum[Warps][kWarpSize];
@@ -446,26 +383,16 @@ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
     }
 
     // Write the reduced result (in the first warp) into the output
-    CLayout::store(
-        C,
-        m,
-        n,
-        mTiles,
-        mTile,
-        // n for C output becomes k for A input, so for m16n8k16,
-        // we need to halve the tiles
-        nTiles / 2,
-        nTile,
-        laneId,
-        sum_f32);
+    CLayout::store(C, m, n, mTiles, mTile,
+                   // n for C output becomes k for A input, so for m16n8k16,
+                   // we need to halve the tiles
+                   nTiles / 2, nTile, laneId, sum_f32);
   }
 #endif
 }
 
-at::Tensor d4_mm_origorder(
-    const at::Tensor& A,
-    const at::Tensor& B,
-    const at::Tensor& CB) {
+at::Tensor d4_mm_origorder(const at::Tensor& A, const at::Tensor& B,
+                           const at::Tensor& CB) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   c10::cuda::CUDAGuard g(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -490,28 +417,20 @@ at::Tensor d4_mm_origorder(
 
   auto grid = dim3(1, nTiles, mTiles);
   auto block = dim3(kWarpSize, Warps);
-  auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_D4, ALayout_RM, 8, 8>;
+  auto kernel =
+      tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_D4, ALayout_RM, 8, 8>;
 
   kernel<<<grid, block, 0, stream>>>(
-      (const half*)A.data_ptr(),
-      (const void*)B.data_ptr(),
-      (const uint64_t*)CB.data_ptr(),
-      (half*)C_final.data_ptr(),
-      m,
-      n,
-      k,
-      mTiles,
-      nTiles,
-      kTiles);
+      (const half*)A.data_ptr(), (const void*)B.data_ptr(),
+      (const uint64_t*)CB.data_ptr(), (half*)C_final.data_ptr(), m, n, k,
+      mTiles, nTiles, kTiles);
 
   return C_final;
 #endif
 }
 
-at::Tensor e8p_mm_origorder(
-    const at::Tensor& A,
-    const at::Tensor& B,
-    const at::Tensor& CB) {
+at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
+                            const at::Tensor& CB) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   c10::cuda::CUDAGuard g(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -536,26 +455,18 @@ at::Tensor e8p_mm_origorder(
 
   auto grid = dim3(1, nTiles, mTiles);
   auto block = dim3(kWarpSize, Warps);
-  auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_E8, ALayout_RM, 8, 8>;
+  auto kernel =
+      tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_E8, ALayout_RM, 8, 8>;
   kernel<<<grid, block, 0, stream>>>(
-      (const half*)A.data_ptr(),
-      (const void*)B.data_ptr(),
-      (const uint64_t*)CB.data_ptr(),
-      (half*)C_final.data_ptr(),
-      m,
-      n,
-      k,
-      mTiles,
-      nTiles,
-      kTiles);
+      (const half*)A.data_ptr(), (const void*)B.data_ptr(),
+      (const uint64_t*)CB.data_ptr(), (half*)C_final.data_ptr(), m, n, k,
+      mTiles, nTiles, kTiles);
 
   return C_final;
 #endif
 }
 
-at::Tensor hi_mm_origorder(
-    const at::Tensor& A,
-    const at::Tensor& B) {
+at::Tensor hi_mm_origorder(const at::Tensor& A, const at::Tensor& B) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   c10::cuda::CUDAGuard g(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -580,18 +491,11 @@ at::Tensor hi_mm_origorder(
 
   auto grid = dim3(1, nTiles, mTiles);
   auto block = dim3(kWarpSize, Warps);
-  auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_HI, ALayout_RM, 8, 8>;
+  auto kernel =
+      tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_HI, ALayout_RM, 8, 8>;
   kernel<<<grid, block, 0, stream>>>(
-      (const half*)A.data_ptr(),
-      (const void*)B.data_ptr(),
-      nullptr,
-      (half*)C_final.data_ptr(),
-      m,
-      n,
-      k,
-      mTiles,
-      nTiles,
-      kTiles);
+      (const half*)A.data_ptr(), (const void*)B.data_ptr(), nullptr,
+      (half*)C_final.data_ptr(), m, n, k, mTiles, nTiles, kTiles);
 
   return C_final;
 #endif
@@ -600,25 +504,23 @@ at::Tensor hi_mm_origorder(
 #define DECOMPRESS_D4_BLOCK_SIZE 256
 
 __global__ void cuda_decompress_d4_origorder_kernel(
-    const uint8_t* __restrict__ YIs,	  // m x (n/4)
-    const c10::Half* __restrict__ CB,           // 256 x 4
-    c10::Half* __restrict__ Y             // m x n
+    const uint8_t* __restrict__ YIs,   // m x (n/4)
+    const c10::Half* __restrict__ CB,  // 256 x 4
+    c10::Half* __restrict__ Y          // m x n
 ) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x;
 
-  for(long r = 0; r < 4; r++) {
-    uint8_t yidx = ((uint8_t*)YIs)[i*4 + r];
-    ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255];
+  for (long r = 0; r < 4; r++) {
+    uint8_t yidx = ((uint8_t*)YIs)[i * 4 + r];
+    ((uint64_t*)Y)[i * 4 + r] = ((uint64_t*)CB)[yidx & 255];
   }
 #endif
 }
 
-
-void decompress_d4_origorder(
-    torch::Tensor YIs,      // m x (n/4)
-    torch::Tensor CB,       // 256 x 4
-    torch::Tensor Y         // m x n
+void decompress_d4_origorder(torch::Tensor YIs,  // m x (n/4)
+                             torch::Tensor CB,   // 256 x 4
+                             torch::Tensor Y     // m x n
 ) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   size_t m = Y.sizes()[0];
@@ -633,27 +535,25 @@ void decompress_d4_origorder(
   assert(CB.sizes()[0] == 256);
 
   const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE);
-  const dim3 blocks(m*n/(16*DECOMPRESS_D4_BLOCK_SIZE));
+  const dim3 blocks(m * n / (16 * DECOMPRESS_D4_BLOCK_SIZE));
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   cuda_decompress_d4_origorder_kernel<<<blocks, threads, 0, stream>>>(
-    YIs.data_ptr<uint8_t>(),
-    CB.data_ptr<c10::Half>(),
-    Y.data_ptr<c10::Half>()
-  );
+      YIs.data_ptr<uint8_t>(), CB.data_ptr<c10::Half>(),
+      Y.data_ptr<c10::Half>());
 #endif
 }
 
 #define DECOMPRESS_E8P_BLOCK_SIZE 256
 
 __global__ void cuda_decompress_e8p_origorder_kernel(
-    const int16_t* __restrict__ YIs,	  // m x (n/8)
-    const int64_t* __restrict__ CB, // 256 x 8
-    c10::Half* __restrict__ Y             // m x n
+    const int16_t* __restrict__ YIs,  // m x (n/8)
+    const int64_t* __restrict__ CB,   // 256 x 8
+    c10::Half* __restrict__ Y         // m x n
 ) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const long i = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x;
   uint16_t yidx = ((uint16_t*)YIs)[i];
-  uint64_t decoded =  BLayout_E8::decode8weights(yidx, CB);
+  uint64_t decoded = BLayout_E8::decode8weights(yidx, CB);
 
   half2 unpacked[2][2];
   uint64_t lower_half = decoded & 0x00ff00ff00ff00ff;
@@ -666,18 +566,16 @@ __global__ void cuda_decompress_e8p_origorder_kernel(
   const half adjust_ = __float2half_rn(-288.0f);
   const half2 adjust = __halves2half2(adjust_, adjust_);
 
-  ((__half2*)Y)[i*4] = __hadd2(unpacked[0][0], adjust); // 01
-  ((__half2*)Y)[i*4+2] = __hadd2(unpacked[0][1], adjust); // 45
-  ((__half2*)Y)[i*4+1] = __hadd2(unpacked[1][0], adjust); // 23
-  ((__half2*)Y)[i*4+3] = __hadd2(unpacked[1][1], adjust); // 67
+  ((__half2*)Y)[i * 4] = __hadd2(unpacked[0][0], adjust);      // 01
+  ((__half2*)Y)[i * 4 + 2] = __hadd2(unpacked[0][1], adjust);  // 45
+  ((__half2*)Y)[i * 4 + 1] = __hadd2(unpacked[1][0], adjust);  // 23
+  ((__half2*)Y)[i * 4 + 3] = __hadd2(unpacked[1][1], adjust);  // 67
 #endif
 }
 
-
-void decompress_e8p_origorder(
-    torch::Tensor YIs,      // m x (n/8)
-    torch::Tensor CB,       // 256 x 8
-    torch::Tensor &Y         // m x n
+void decompress_e8p_origorder(torch::Tensor YIs,  // m x (n/8)
+                              torch::Tensor CB,   // 256 x 8
+                              torch::Tensor& Y    // m x n
 ) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   size_t m = Y.sizes()[0];
@@ -692,21 +590,18 @@ void decompress_e8p_origorder(
   assert(CB.sizes()[0] == 256);
 
   const dim3 threads(DECOMPRESS_E8P_BLOCK_SIZE);
-  const dim3 blocks(m*n/(8*DECOMPRESS_E8P_BLOCK_SIZE));
+  const dim3 blocks(m * n / (8 * DECOMPRESS_E8P_BLOCK_SIZE));
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   cuda_decompress_e8p_origorder_kernel<<<blocks, threads, 0, stream>>>(
-    YIs.data_ptr<int16_t>(),
-    CB.data_ptr<int64_t>(),
-    Y.data_ptr<c10::Half>()
-  );
+      YIs.data_ptr<int16_t>(), CB.data_ptr<int64_t>(), Y.data_ptr<c10::Half>());
 #endif
 }
 
 #define DECOMPRESS_HI_BLOCK_SIZE 256
 
 __global__ void cuda_decompress_hi_origorder_kernel(
-    const uint32_t* __restrict__ YIs,	  // m x (n/8)
-    c10::Half* __restrict__ Y             // m x n
+    const uint32_t* __restrict__ YIs,  // m x (n/8)
+    c10::Half* __restrict__ Y          // m x n
 ) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   const long i = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x;
@@ -718,23 +613,21 @@ __global__ void cuda_decompress_hi_origorder_kernel(
   const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
   const half2 z16 = __halves2half2(z16_, z16_);
 
-
   uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
-  uint32_t q1 = ((qa & 0x00f000f0)| c0);
+  uint32_t q1 = ((qa & 0x00f000f0) | c0);
   qa >>= 8;
   uint32_t q2 = (((qa & 0x000f000f) << 4) | c0);
   uint32_t q3 = ((qa & 0x00f000f0) | c0);
-  ((__half2*)Y)[i*4] = __hfma2(*((half2*)(&q0)), y16, z16);
-  ((__half2*)Y)[i*4+1] = __hfma2(*((half2*)(&q1)), y16, z16);
-  ((__half2*)Y)[i*4+2] = __hfma2(*((half2*)(&q2)), y16, z16);
-  ((__half2*)Y)[i*4+3] = __hfma2(*((half2*)(&q3)), y16, z16);
+  ((__half2*)Y)[i * 4] = __hfma2(*((half2*)(&q0)), y16, z16);
+  ((__half2*)Y)[i * 4 + 1] = __hfma2(*((half2*)(&q1)), y16, z16);
+  ((__half2*)Y)[i * 4 + 2] = __hfma2(*((half2*)(&q2)), y16, z16);
+  ((__half2*)Y)[i * 4 + 3] = __hfma2(*((half2*)(&q3)), y16, z16);
 #endif
 }
 
-void decompress_hi_origorder(
-    torch::Tensor YIs,      // m x (n/8)
-    torch::Tensor Y         // m x n
-){
+void decompress_hi_origorder(torch::Tensor YIs,  // m x (n/8)
+                             torch::Tensor Y     // m x n
+) {
 #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
   size_t m = Y.sizes()[0];
   size_t n = Y.sizes()[1];
@@ -746,11 +639,9 @@ void decompress_hi_origorder(
   assert(YIs.sizes()[1] * 8 == n);
 
   const dim3 threads(DECOMPRESS_HI_BLOCK_SIZE);
-  const dim3 blocks(m*n/(8*DECOMPRESS_HI_BLOCK_SIZE));
+  const dim3 blocks(m * n / (8 * DECOMPRESS_HI_BLOCK_SIZE));
   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
   cuda_decompress_hi_origorder_kernel<<<blocks, threads, 0, stream>>>(
-    (uint32_t*)YIs.data_ptr<int32_t>(),
-    Y.data_ptr<c10::Half>()
-  );
+      (uint32_t*)YIs.data_ptr<int32_t>(), Y.data_ptr<c10::Half>());
 #endif
 }

+ 27 - 37
kernels/quantization/squeezellm/quant_cuda_kernel.cu

@@ -1,5 +1,4 @@
 #include <torch/all.h>
-#include <torch/python.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
@@ -22,27 +21,23 @@ __device__ inline unsigned int as_unsigned(int i) {
 // 4-bit matvec kernel (LUT-based)
 __global__ void NUQ4MatMulKernel(
 #ifndef USE_ROCM
-    const  half2* __restrict__ vec,
+    const half2* __restrict__ vec,
 #else
-    const  __half2* __restrict__ vec,
+    const __half2* __restrict__ vec,
 #endif
-    const    int* __restrict__ mat,
+    const int* __restrict__ mat,
 #ifndef USE_ROCM
-           half2* __restrict__ mul,
+    half2* __restrict__ mul,
 #else
-          float2* __restrict__ mul,
+    float2* __restrict__ mul,
 #endif
-    const  __half* __restrict__ lookup_table,
-    int height,
-    int width,
-    int batch,
-    int vec_height
-) {
+    const __half* __restrict__ lookup_table, int height, int width, int batch,
+    int vec_height) {
 
   const int blockwidth2 = BLOCKWIDTH / 2;
 
   int row = BLOCKHEIGHT4 * blockIdx.x;
-  int col =  BLOCKWIDTH * blockIdx.y + threadIdx.x;
+  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
 
 #ifndef USE_ROCM
   __shared__ half2 blockvec[blockwidth2];
@@ -73,14 +68,16 @@ __global__ void NUQ4MatMulKernel(
   unsigned int tmp1;
   unsigned int lut_index1, lut_index2;
 
-  for (int b = 0; b < batch; ++b){
+  for (int b = 0; b < batch; ++b) {
     i = width * row + col;
     res = __int2half_rd(0);
     k = 0;
 
     __syncthreads();
     if (threadIdx.x < blockwidth2)
-      blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
+      blockvec[threadIdx.x] =
+          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
+              threadIdx.x];
     __syncthreads();
 
     while (k < blockwidth2) {
@@ -143,7 +140,8 @@ __global__ void NUQ4MatMulKernel(
 #ifndef USE_ROCM
       res = __hadd(__hadd(res2.x, res2.y), res);
 #else
-      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
+      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
+                   res);
 #endif
 
       i += width;
@@ -179,46 +177,38 @@ __global__ void NUQ4MatMulKernel(
   }
 }
 
-} // namespace squeezellm
-} // namespace aphrodite
+}  // namespace squeezellm
+}  // namespace aphrodite
 
 // 4-bit matvec kernel (LUT-based)
-void squeezellm_gemm(
-  torch::Tensor vec,
-  torch::Tensor mat,
-  torch::Tensor mul,
-  torch::Tensor lookup_table
-) {
+void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+                     torch::Tensor lookup_table) {
   int height = mat.size(0);
   int width = mat.size(1);
 
   int batch = vec.size(0);
   int vec_height = vec.size(1);
 
-  dim3 blocks(
-    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
-    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
-  );
+  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
+              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
   dim3 threads(BLOCKWIDTH);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
-    (half2*) vec.data<at::Half>(),
+      (half2*)vec.data<at::Half>(),
 #else
-    (__half2*) vec.data_ptr<at::Half>(),
+      (__half2*)vec.data_ptr<at::Half>(),
 #endif
-    mat.data_ptr<int>(),
+      mat.data_ptr<int>(),
 #ifndef USE_ROCM
-    (half2*) mul.data<at::Half>(),
-    (__half*) lookup_table.data<at::Half>(),
+      (half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
 #else
-    (float2*) mul.data_ptr<float>(),
-    (__half*) lookup_table.data_ptr<at::Half>(),
+      (float2*)mul.data_ptr<float>(),
+      (__half*)lookup_table.data_ptr<at::Half>(),
 #endif
-    height, width, batch, vec_height
-  );
+      height, width, batch, vec_height);
 }
 
 #undef BLOCKWIDTH

+ 22 - 0
kernels/registration.h

@@ -0,0 +1,22 @@
+#pragma once
+
+#include <Python.h>
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+// REGISTER_EXTENSION allows the shared library to be loaded and initialized
+// via python's import statement.
+#define REGISTER_EXTENSION(NAME)                                               \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                           \
+  }

+ 292 - 0
kernels/torch_bindings.cpp

@@ -0,0 +1,292 @@
+#include "cache.h"
+#include "cuda_utils.h"
+#include "ops.h"
+#include "registration.h"
+#include "quantization/quant_ops.h"
+
+#include <torch/library.h>
+
+// Note on op signatures:
+// The X_meta signatures are for the meta functions corresponding to op X.
+// They must be kept in sync with the signature for X. Generally, only
+// functions that return Tensors require a meta function.
+//
+// See the following links for detailed docs on op registration and function
+// schemas.
+// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  // Aphrodite custom ops
+
+  // Attention ops
+  // Compute the attention between an input query and the cached
+  // keys/values using PagedAttention.
+  ops.def(
+      "paged_attention_v1("
+      "    Tensor! out, Tensor query, Tensor key_cache,"
+      "    Tensor value_cache, int num_kv_heads, float scale,"
+      "    Tensor block_tables, Tensor seq_lens, int block_size,"
+      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
+      "    int blocksparse_local_blocks,"
+      "    int blocksparse_vert_stride, int blocksparse_block_size,"
+      "    int blocksparse_head_sliding_step) -> ()");
+  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
+
+  // PagedAttention V2.
+  ops.def(
+      "paged_attention_v2("
+      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
+      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
+      "    Tensor value_cache, int num_kv_heads, float scale,"
+      "    Tensor block_tables, Tensor seq_lens, int block_size,"
+      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
+      "    int blocksparse_local_blocks,"
+      "    int blocksparse_vert_stride, int blocksparse_block_size,"
+      "    int blocksparse_head_sliding_step) -> ()");
+  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
+
+  // Activation ops
+  // Activation function used in SwiGLU.
+  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
+
+  // Activation function used in GeGLU with `none` approximation.
+  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
+
+  // Activation function used in GeGLU with `tanh` approximation.
+  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
+
+  // GELU implementation used in GPT-2.
+  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_new", torch::kCUDA, &gelu_new);
+
+  // Approximate GELU implementation.
+  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
+
+  // Layernorm
+  // Apply Root Mean Square (RMS) Normalization to the input tensor.
+  ops.def(
+      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
+      "()");
+  ops.impl("rms_norm", torch::kCUDA, &rms_norm);
+
+  // In-place fused Add and RMS Normalization.
+  ops.def(
+      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
+      "float epsilon) -> ()");
+  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
+
+  // Rotary embedding
+  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
+  ops.def(
+      "rotary_embedding(Tensor positions, Tensor! query,"
+      "                 Tensor! key, int head_size,"
+      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
+  ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
+
+  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
+  // (supports multiple loras).
+  ops.def(
+      "batched_rotary_embedding(Tensor positions, Tensor! query,"
+      "                         Tensor! key, int head_size,"
+      "                         Tensor cos_sin_cache, bool is_neox,"
+      "                         int rot_dim,"
+      "                         Tensor cos_sin_cache_offsets) -> ()");
+  ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
+
+  // Quantization ops
+#ifndef USE_ROCM
+  // Quantized GEMM for AQLM.
+  ops.def("aqlm_gemm", &aqlm_gemm);
+  ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
+
+  // Decompression method for AQLM.
+  ops.def("aqlm_dequant", &aqlm_dequant);
+  ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
+
+  // Quantized GEMM for AWQ.
+  ops.def("awq_gemm", &awq_gemm);
+  ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
+
+  // Dequantization for AWQ.
+  ops.def("awq_dequantize", &awq_dequantize);
+  ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
+
+  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
+  ops.def("marlin_gemm", &marlin_gemm);
+  ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
+
+  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
+  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
+
+  // gptq_marlin Optimized Quantized GEMM for GPTQ.
+  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
+
+  // gptq_marlin repack from GPTQ.
+  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
+
+  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
+  // quantization.
+  ops.def(
+      "cutlass_scaled_mm_dq(Tensor! out, Tensor a,"
+      "                     Tensor b, Tensor a_scales,"
+      "                     Tensor b_scales) -> ()");
+  ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
+
+  // QuIP# GEMV
+  ops.def("quip_gemv", &e8p_mm_origorder);
+  ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
+
+  // QuIP# Decompress
+  ops.def("quip_decompress", &decompress_e8p_origorder);
+  ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
+#endif
+
+  // Quantized GEMM for GPTQ.
+  ops.def("gptq_gemm", &gptq_gemm);
+  ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
+
+  // Post processing for GPTQ.
+  ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
+  ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
+
+  // Quantized GEMM for SqueezeLLM.
+  ops.def(
+      "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
+      "lookup_table) -> ()");
+  ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
+
+  // Compute FP8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
+  ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
+
+  // Compute FP8 quantized tensor and scaling factor.
+  ops.def(
+      "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
+      "()");
+  ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
+
+  // Aligning the number of tokens to be processed by each expert such
+  // that it is divisible by the block size.
+  ops.def(
+      "moe_align_block_size(Tensor topk_ids, int num_experts,"
+      "                     int block_size, Tensor! sorted_token_ids,"
+      "                     Tensor! experts_ids,"
+      "                     Tensor! num_tokens_post_pad) -> ()");
+  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+
+  // Compute int8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
+      "()");
+  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
+
+  // Compute int8 quantized tensor and scaling factor
+  ops.def(
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
+      "()");
+  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
+           &dynamic_scaled_int8_quant);
+}
+
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
+  // Cache ops
+  // Swap in (out) the cache blocks from src to dst.
+  cache_ops.def(
+      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
+  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
+
+  // Copy the cache blocks from src to dst.
+  cache_ops.def(
+      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
+      "block_mapping) -> ()");
+  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
+
+  // Reshape the key and value tensors and cache them.
+  cache_ops.def(
+      "reshape_and_cache(Tensor key, Tensor value,"
+      "                  Tensor! key_cache, Tensor! value_cache,"
+      "                  Tensor slot_mapping,"
+      "                  str kv_cache_dtype,"
+      "                  float kv_scale) -> ()");
+  cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
+
+  // Reshape the key and value tensors and cache them.
+  cache_ops.def(
+      "reshape_and_cache_flash(Tensor key, Tensor value,"
+      "                        Tensor! key_cache,"
+      "                        Tensor! value_cache,"
+      "                        Tensor slot_mapping,"
+      "                        str kv_cache_dtype) -> ()");
+  cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
+                 &reshape_and_cache_flash);
+
+  // Convert the key and value cache to fp8 data type.
+  cache_ops.def(
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
+      "kv_cache_dtype) -> ()");
+  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
+}
+
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
+  // Cuda utils
+
+  // Gets the specified device attribute.
+  cuda_utils.def("get_device_attribute", &get_device_attribute);
+  cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
+
+  // Gets the maximum shared memory per block device attribute.
+  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
+                 &get_max_shared_memory_per_block_device_attribute);
+  cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
+                  torch::kCUDA,
+                  &get_max_shared_memory_per_block_device_attribute);
+}
+
+#ifndef USE_ROCM
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
+  // Custom all-reduce kernels
+  custom_ar.def("init_custom_ar", &init_custom_ar);
+  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
+
+  custom_ar.def("should_custom_ar", &should_custom_ar);
+  custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
+
+  custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
+  custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
+
+  custom_ar.def(
+      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
+      "()");
+  custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
+
+  custom_ar.def("dispose", &dispose);
+  custom_ar.impl("dispose", torch::kCPU, &dispose);
+
+  custom_ar.def("meta_size", &meta_size);
+  custom_ar.impl("meta_size", torch::kCPU, &meta_size);
+
+  custom_ar.def("register_buffer", &register_buffer);
+  custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
+
+  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
+                 &get_graph_buffer_ipc_meta);
+
+  custom_ar.def("register_graph_buffers", &register_graph_buffers);
+  custom_ar.impl("register_graph_buffers", torch::kCPU,
+                 &register_graph_buffers);
+}
+#endif
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

+ 15 - 32
setup.py

@@ -46,7 +46,7 @@ def remove_prefix(text, prefix):
 class CMakeExtension(Extension):
 
     def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
-        super().__init__(name, sources=[], **kwa)
+        super().__init__(name, sources=[], py_limited_api=True, **kwa)
         self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
 
 
@@ -143,14 +143,11 @@ class cmake_build_ext(build_ext):
             '-DAPHRODITE_PYTHON_EXECUTABLE={}'.format(sys.executable)
         ]
 
-        if _install_quants():
-            cmake_args += ['-DAPHRODITE_INSTALL_QUANT_KERNELS=ON']
-
         if _install_punica():
             cmake_args += ['-DAPHRODITE_INSTALL_PUNICA_KERNELS=ON']
 
-        if _install_hadamard():
-            cmake_args += ['-DAPHRODITE_INSTALL_HADAMARD_KERNELS=ON']
+        # if _install_hadamard():
+        #     cmake_args += ['-DAPHRODITE_INSTALL_HADAMARD_KERNELS=ON']
 
         #
         # Setup parallelism and build tool
@@ -227,18 +224,6 @@ def _is_cpu() -> bool:
     return APHRODITE_TARGET_DEVICE == "cpu"
 
 
-def _install_quants() -> bool:
-    install_quants = bool(
-        int(os.getenv("APHRODITE_INSTALL_QUANT_KERNELS", "1")))
-    device_count = torch.cuda.device_count()
-    for i in range(device_count):
-        major, minor = torch.cuda.get_device_capability(i)
-        if major < 6:
-            install_quants = False
-            break
-    return install_quants
-
-
 def _install_punica() -> bool:
     install_punica = bool(
         int(os.getenv("APHRODITE_INSTALL_PUNICA_KERNELS", "1")))
@@ -251,16 +236,16 @@ def _install_punica() -> bool:
     return install_punica
 
 
-def _install_hadamard() -> bool:
-    install_hadamard = bool(
-        int(os.getenv("APHRODITE_INSTALL_HADAMARD_KERNELS", "1")))
-    device_count = torch.cuda.device_count()
-    for i in range(device_count):
-        major, minor = torch.cuda.get_device_capability(i)
-        if major <= 6:
-            install_hadamard = False
-            break
-    return install_hadamard
+# def _install_hadamard() -> bool:
+#     install_hadamard = bool(
+#         int(os.getenv("APHRODITE_INSTALL_HADAMARD_KERNELS", "1")))
+#     device_count = torch.cuda.device_count()
+#     for i in range(device_count):
+#         major, minor = torch.cuda.get_device_capability(i)
+#         if major <= 6:
+#             install_hadamard = False
+#             break
+#     return install_hadamard
 
 
 def get_hipcc_rocm_version():
@@ -416,13 +401,11 @@ if _is_cuda() or _is_hip():
 
 if not _is_neuron():
     ext_modules.append(CMakeExtension(name="aphrodite._C"))
-    if _install_quants() and _is_cuda() or _is_hip():
-        ext_modules.append(CMakeExtension(name="aphrodite._quant_C"))
     if _install_punica() and _is_cuda() or _is_hip():
         ext_modules.append(CMakeExtension(name="aphrodite._punica_C"))
     # TODO: see if hadamard kernels work with HIP
-    if _install_hadamard() and _is_cuda():
-        ext_modules.append(CMakeExtension(name="aphrodite._hadamard_C"))
+    # if _install_hadamard() and _is_cuda():
+    #     ext_modules.append(CMakeExtension(name="aphrodite._hadamard_C"))
 
 package_data = {
     "aphrodite": [

+ 1 - 1
tests/benchmarks/attention.py

@@ -4,7 +4,7 @@ import time
 
 import torch
 
-from aphrodite._C import ops as attention_ops
+from aphrodite import _custom_ops as attention_ops
 
 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512