
initial support; port most of the kernels

AlpinDale 4 months ago
parent
commit
378d624b3d

+ 3 - 1
.gitignore

@@ -199,4 +199,6 @@ _build/
 kv_cache_states/*
 quant_params/*
 .ruff_cache/
-images/
+images/
+
+*.pdb

+ 31 - 20
CMakeLists.txt

@@ -66,6 +66,16 @@ endif()
 #
 find_package(Torch REQUIRED)
 
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CUDA_STANDARD 17)
+set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+
+# Build CUDA sources as C++17 by appending --std=c++17 to APHRODITE_GPU_FLAGS
+if(APHRODITE_GPU_LANG STREQUAL "CUDA")
+  list(APPEND APHRODITE_GPU_FLAGS "--std=c++17")
+endif()
+
 #
 # Add the `default` target which detects which extensions should be
 # built based on platform/architecture.  This is the same logic that
@@ -188,15 +198,15 @@ set(APHRODITE_EXT_SRC
   "kernels/torch_bindings.cpp")
 
 if(APHRODITE_GPU_LANG STREQUAL "CUDA")
-  include(FetchContent)
-  SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
-  FetchContent_Declare(
-        cutlass
-        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        # CUTLASS 3.5.1
-        GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 
-  )
-  FetchContent_MakeAvailable(cutlass)
+#  include(FetchContent)
+#  SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
+#  FetchContent_Declare(
+#        cutlass
+#        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+#        # CUTLASS 3.5.1
+#        GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 
+#  )
+#  FetchContent_MakeAvailable(cutlass)
 
   list(APPEND APHRODITE_EXT_SRC
     "kernels/quantization/fp6/fp6_linear.cu"
@@ -214,22 +224,23 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
     "kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
     "kernels/quantization/fp8/fp8_marlin.cu"
     "kernels/all_reduce/custom_all_reduce.cu"
-    "kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu"
-    "kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
-    "kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+    #"kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    #"kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
+    #"kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+    )
 
   #
   # The CUTLASS kernels for Hopper require sm90a to be enabled.
   # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
   # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
-    set_source_files_properties(
-          "kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
-          "kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
-          PROPERTIES
-          COMPILE_FLAGS
-          "-gencode arch=compute_90a,code=sm_90a -Wno-psabi")
-  endif()
+  #if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
+  #  set_source_files_properties(
+  #        "kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
+  #        "kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+  #        PROPERTIES
+  #        COMPILE_FLAGS
+  #        "-gencode arch=compute_90a,code=sm_90a -Wno-psabi")
+  #endif()
 
 endif()
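
The block above pins both host and device compilation to C++17. A hypothetical sanity check (not in the repo) that the setting actually reaches each translation unit might look like the following; MSVC keeps __cplusplus at 199711L unless /Zc:__cplusplus is passed, so _MSVC_LANG is consulted there instead.

// Hypothetical check, not part of the commit: fail the build early if the
// C++17 setting from CMakeLists.txt did not take effect for this file.
#if defined(_MSVC_LANG)
static_assert(_MSVC_LANG >= 201703L, "expected C++17 or newer");
#else
static_assert(__cplusplus >= 201703L, "expected C++17 or newer");
#endif
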
 

+ 2 - 1
cmake/utils.cmake

@@ -3,7 +3,8 @@
 # `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
 #
 macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
-  file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
+  file(TO_CMAKE_PATH "${EXECUTABLE}" EXECUTABLE_PATH)
+  file(REAL_PATH "${EXECUTABLE_PATH}" EXECUTABLE)
   set(Python_EXECUTABLE ${EXECUTABLE})
   find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
   if (NOT Python_FOUND)

+ 2 - 2
kernels/attention/attention_kernels.cu

@@ -208,8 +208,8 @@ __device__ void paged_attention_kernel(
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
 
   // blocksparse specific vars
-  int bs_block_offset;
-  int q_bs_block_id;
+  [[maybe_unused]] int bs_block_offset;
+  [[maybe_unused]] int q_bs_block_id;
   if constexpr (IS_BLOCK_SPARSE) {
     // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
     // blocksparse_block_size);
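
The [[maybe_unused]] additions here and in the Mamba kernels below silence unused-variable diagnostics for variables that are only touched inside an if constexpr branch. A minimal stand-alone sketch of the pattern (names are illustrative, not from the kernel):

#include <cstdio>

// When COND is false, the if-constexpr branch that reads `offset` is
// discarded, so without the attribute some compilers warn (or error under
// warnings-as-errors) about an unused variable.
template <bool COND>
void process(int seq_len, int block_size) {
  [[maybe_unused]] int offset = 0;
  if constexpr (COND) {
    offset = seq_len / block_size;
    std::printf("offset=%d\n", offset);
  }
}

int main() {
  process<false>(128, 16);  // offset is unused here; no warning is emitted
  process<true>(128, 16);
}
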

+ 4 - 4
kernels/cache_kernels.cu

@@ -111,8 +111,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 
   // Create data structures for the kernel.
   // Create an array of pointers to the key and value caches.
-  int64_t key_cache_ptrs[num_layers];
-  int64_t value_cache_ptrs[num_layers];
+  std::vector<int64_t> key_cache_ptrs(num_layers);
+  std::vector<int64_t> value_cache_ptrs(num_layers);
   for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
     key_cache_ptrs[layer_idx] =
         reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
@@ -126,10 +126,10 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
   // Move the data structures to the GPU.
   // NOTE: This synchronizes the CPU and GPU.
   torch::Tensor key_cache_ptrs_tensor =
-      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
+      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
           .to(cache_device);
   torch::Tensor value_cache_ptrs_tensor =
-      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
+      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
           .to(cache_device);
 
   // Launch the kernel.
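
The stack arrays (int64_t key_cache_ptrs[num_layers]) were variable-length arrays, a GCC/Clang extension that MSVC rejects, so the commit switches to std::vector and hands .data() to torch::from_blob. A stand-alone sketch of the same pattern, with a made-up layer count and fake addresses:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int num_layers = 4;
  // Heap-backed, contiguous replacement for the former VLA.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = 0x1000 + 0x100 * layer_idx;  // fake addresses
  }
  // key_cache_ptrs.data() is the int64_t* a from_blob-style API consumes.
  std::printf("first=0x%llx count=%zu\n",
              static_cast<unsigned long long>(key_cache_ptrs[0]),
              key_cache_ptrs.size());
}
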

+ 26 - 13
kernels/core/scalar_type.hpp

@@ -36,7 +36,7 @@ class ScalarType {
         signed_(signed_),
         bias(bias),
         finite_values_only(finite_values_only),
-        nan_repr(nan_repr){};
+        nan_repr(nan_repr) {}
 
   static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
     return ScalarType(0, size_bits - 1, true, bias);
@@ -107,19 +107,32 @@ class ScalarType {
                                  finite_values_only, nan_repr);
   };
 
-  template <typename Fn, typename Init>
-  static constexpr auto reduce_member_types(Fn f, Init init) {
-    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
-    return dummy_type.reduce_members(f, init);
-  };
+private:
+    template <typename Fn, typename Init>
+    static constexpr auto reduce_member_types(Fn f, Init init) {
+        // Directly specify the types in the same order as the members
+        using MemberTypes = std::tuple<
+            uint8_t,    // exponent
+            uint8_t,    // mantissa
+            bool,       // signed_
+            int32_t,    // bias
+            bool,       // finite_values_only
+            NanRepr     // nan_repr
+        >;
+        
+        return std::apply([f, init](auto... types) {
+            return reduce_members_helper(f, init, 
+                typename std::decay<decltype(types)>::type{}...);
+        }, MemberTypes{});
+    }
 
-  static constexpr auto id_size_bits() {
-    return reduce_member_types(
-        [](int acc, auto member) -> int {
-          return acc + member_id_field_width<decltype(member)>();
-        },
-        0);
-  }
+    static constexpr auto id_size_bits() {
+        return reduce_member_types(
+            [](int acc, auto member) -> int {
+                return acc + member_id_field_width<decltype(member)>();
+            },
+            0);
+    }
 
  public:
   // unique id for this scalar type that can be computed at compile time for
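
The rewritten reduce_member_types builds a std::tuple of the member types and folds over default-constructed instances of them via std::apply. A self-contained toy version of that idea (the types, the field_width helper, and the bit total are illustrative, not the repo's member_id_field_width logic):

#include <cstdint>
#include <tuple>
#include <type_traits>

// One bit-width per member type; the real code derives this per member.
template <typename T>
constexpr int field_width() {
  return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}

constexpr int total_id_bits() {
  // Same ordering idea as above: exponent, mantissa, signed_, bias,
  // finite_values_only, nan_repr (an enum in the real code, uint8_t here).
  using MemberTypes = std::tuple<uint8_t, uint8_t, bool, int32_t, bool, uint8_t>;
  return std::apply(
      [](auto... members) {
        return (0 + ... + field_width<decltype(members)>());
      },
      MemberTypes{});
}

static_assert(total_id_bits() == 8 + 8 + 1 + 32 + 1 + 8);
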

+ 11 - 9
kernels/mamba/causal_conv1d/causal_conv1d.cu

@@ -124,7 +124,7 @@ at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
     TORCH_CHECK(
         dim % 8 == 0,
         "causal_conv1d only supports channel dimension divisible by 8 for now");
-    TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0,
+    TORCH_CHECK(x.stride(2) % 8 == 0 && x.stride(0) % 8 == 0,
                 "causal_conv1d with channel last layout requires strides "
                 "(x.stride(0) and x.stride(2)) to be multiples of 8");
   }
@@ -342,18 +342,18 @@ __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(
 
   // Shared memory.
   extern __shared__ char smem_[];
-  auto& smem_load =
+  [[maybe_unused]] auto& smem_load =
       reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-  auto& smem_load_vec =
+  [[maybe_unused]] auto& smem_load_vec =
       reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
-  auto& smem_load_index =
+  [[maybe_unused]] auto& smem_load_index =
       reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
-  auto& smem_load_index_vec =
+  [[maybe_unused]] auto& smem_load_index_vec =
       reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(
           smem_);
-  auto& smem_store =
+  [[maybe_unused]] auto& smem_store =
       reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
-  auto& smem_store_vec =
+  [[maybe_unused]] auto& smem_store_vec =
       reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
   vec_t* smem_exchange = reinterpret_cast<vec_t*>(smem_ + Ktraits::kSmemIOSize);
 
@@ -589,7 +589,7 @@ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel(
                  batch_id * params.out_batch_stride +
                  (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride +
                  chunk_c_id * kChunkSizeC + c_idx * kNElts;
-  int* seq_idx = !kHasSeqIdx
+  [[maybe_unused]] int* seq_idx = !kHasSeqIdx
                      ? nullptr
                      : reinterpret_cast<int*>(params.seq_idx_ptr) +
                            batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
@@ -702,7 +702,9 @@ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel(
 #pragma unroll
   for (int i = 0; i < kLPerThread; ++i) {
     out_vals[i] = bias_val;
-    const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
+    [[maybe_unused]] const int seq_idx_cur = !kHasSeqIdx
+                                              ? 0
+                                              : seq_idx_thread[i + kWidth - 1];
 #pragma unroll
     for (int w = 0; w < kWidth; ++w) {
       if constexpr (!kHasSeqIdx) {

+ 8 - 8
kernels/mamba/mamba_ssm/selective_scan_fwd.cu

@@ -116,11 +116,11 @@ __global__ __launch_bounds__(
   // reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
   auto& smem_load =
       reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
-  auto& smem_load_weight =
+  [[maybe_unused]] auto& smem_load_weight =
       reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
-  auto& smem_load_index =
+  [[maybe_unused]] auto& smem_load_index =
       reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
-  auto& smem_load_weight1 =
+  [[maybe_unused]] auto& smem_load_weight1 =
       *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(
           smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
   auto& smem_store =
@@ -145,12 +145,12 @@ __global__ __launch_bounds__(
                    dim_id * kNRows * params.delta_d_stride;
   weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr) +
                 dim_id * kNRows * params.A_d_stride;
-  weight_t* B = reinterpret_cast<weight_t*>(params.B_ptr) +
+  [[maybe_unused]] weight_t* B = reinterpret_cast<weight_t*>(params.B_ptr) +
                 dim_id * kNRows * params.B_d_stride;
   input_t* Bvar = reinterpret_cast<input_t*>(params.B_ptr) +
                   batch_id * params.B_batch_stride +
                   group_id * params.B_group_stride;
-  weight_t* C = reinterpret_cast<weight_t*>(params.C_ptr) +
+  [[maybe_unused]] weight_t* C = reinterpret_cast<weight_t*>(params.C_ptr) +
                 dim_id * kNRows * params.C_d_stride;
   input_t* Cvar = reinterpret_cast<input_t*>(params.C_ptr) +
                   batch_id * params.C_batch_stride +
@@ -158,7 +158,7 @@ __global__ __launch_bounds__(
   scan_t* x = reinterpret_cast<scan_t*>(params.x_ptr) +
               (batch_id * params.dim + dim_id * kNRows) * params.n_chunks *
                   params.dstate;
-  int* index = !kUseIndex ? nullptr
+  [[maybe_unused]] int* index = !kUseIndex ? nullptr
                           : reinterpret_cast<int*>(params.index_ptr) +
                                 batch_id * params.seqlen;
 
@@ -188,7 +188,7 @@ __global__ __launch_bounds__(
   constexpr int kChunkSize = kNThreads * kNItems;
   for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
     input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
-    int index_vals_load[kNRows][kNItems];
+    [[maybe_unused]] int index_vals_load[kNRows][kNItems];
 
     __syncthreads();
 #pragma unroll
@@ -397,7 +397,7 @@ template <int kNThreads, int kNItems, typename input_t, typename weight_t>
 void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) {
   // Only kNRows == 1 is tested for now, which ofc doesn't differ from
   // previously when we had each block processing 1 row.
-  constexpr int kNRows = 1;
+  static constexpr int kNRows = 1;
   BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
     BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
       BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {

+ 2 - 2
kernels/mamba/mamba_ssm/static_switch.h

@@ -17,10 +17,10 @@
 #define BOOL_SWITCH(COND, CONST_NAME, ...) \
   [&] {                                    \
     if (COND) {                            \
-      constexpr bool CONST_NAME = true;    \
+      static constexpr bool CONST_NAME = true;    \
       return __VA_ARGS__();                \
     } else {                               \
-      constexpr bool CONST_NAME = false;   \
+      static constexpr bool CONST_NAME = false;   \
       return __VA_ARGS__();                \
     }                                      \
   }()
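
Adding static to the constexpr booleans gives them static storage duration, so the nested lambda passed as __VA_ARGS__ can use them as template arguments without capturing them; this appears to be a workaround for MSVC, which is stricter about lambda-local constexpr variables in that position. A minimal, host-only sketch of the dispatch pattern:

#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      static constexpr bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      static constexpr bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

// Stand-in for a templated kernel launcher.
template <bool kIsEvenLen>
void run_kernel(int n) {
  std::printf("n=%d kIsEvenLen=%d\n", n, kIsEvenLen);
}

int main(int argc, char**) {
  int n = argc * 128;  // runtime value
  // The runtime condition picks which compile-time instantiation runs.
  BOOL_SWITCH(n % 256 == 0, kIsEvenLen, [&] { run_kernel<kIsEvenLen>(n); });
}
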

+ 20 - 20
kernels/quantization/awq/gemm_kernels.cu

@@ -176,7 +176,7 @@ __global__ void __launch_bounds__(64)
     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
       {
         unsigned int addr;
-        __asm__ __volatile__(
+        asm(
             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
             "addr; }\n"
             : "=r"(addr)
@@ -184,7 +184,7 @@ __global__ void __launch_bounds__(64)
                           (((((int)threadIdx.x) & 15) * 40) +
                            ((((int)threadIdx.x) >> 4) * 8)))));
 
-        __asm__ __volatile__(
+        asm(
             "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
             "{%0, %1, %2, %3}, [%4];\n"
             : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
@@ -197,7 +197,7 @@ __global__ void __launch_bounds__(64)
       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
         {
           unsigned int addr;
-          __asm__ __volatile__(
+          asm(
               "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
               "addr; }\n"
               : "=r"(addr)
@@ -206,7 +206,7 @@ __global__ void __launch_bounds__(64)
                                          (ax1_0 * 16))])) +
                             (((((int)threadIdx.x) & 15) * (N + 8)) +
                              ((((int)threadIdx.x) >> 4) * 8)))));
-          __asm__ __volatile__(
+          asm(
               "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
               "{%0, %1, %2, %3}, [%4];\n"
               : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
@@ -219,7 +219,7 @@ __global__ void __launch_bounds__(64)
       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -236,7 +236,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -253,7 +253,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -270,7 +270,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -287,7 +287,7 @@ __global__ void __launch_bounds__(64)
         }
   #else
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
@@ -308,7 +308,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
@@ -558,7 +558,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
       {
         unsigned int addr;
-        __asm__ __volatile__(
+        asm(
             "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
             "addr; }\n"
             : "=r"(addr)
@@ -566,7 +566,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
                           (((((int)threadIdx.x) & 15) * 40) +
                            ((((int)threadIdx.x) >> 4) * 8)))));
 
-        __asm__ __volatile__(
+        asm(
             "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
             "{%0, %1, %2, %3}, [%4];\n"
             : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
@@ -579,7 +579,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
         {
           unsigned int addr;
-          __asm__ __volatile__(
+          asm(
               "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
               "addr; }\n"
               : "=r"(addr)
@@ -588,7 +588,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
                                          (ax1_0 * 16))])) +
                             (((((int)threadIdx.x) & 15) * (N + 8)) +
                              ((((int)threadIdx.x) >> 4) * 8)))));
-          __asm__ __volatile__(
+          asm(
               "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
               "{%0, %1, %2, %3}, [%4];\n"
               : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
@@ -601,7 +601,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -618,7 +618,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -635,7 +635,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -652,7 +652,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -669,7 +669,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
   #else
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
@@ -690,7 +690,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          __asm__ __volatile__(
+          asm(
               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"

+ 5 - 2
kernels/quantization/fp8/common.cu

@@ -10,8 +10,11 @@
 
 #ifndef USE_ROCM
 using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
-    std::numeric_limits<FP8_TYPE>::max();
+#ifdef _WIN32
+#define FP8_E4M3_MAX (std::numeric_limits<FP8_TYPE>::max())
+#else
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#endif
 #else
   #include "amd/hip_float8.h"
 using FP8_TYPE = c10::Float8_e4m3fnuz;
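
The _WIN32 branch above suggests MSVC will not evaluate std::numeric_limits<c10::Float8_e4m3fn>::max() as a constexpr initializer, so the constant becomes a macro that expands to the same expression at each use site. A generic sketch of that fallback, with float standing in for FP8_TYPE so it compiles anywhere:

#include <limits>

#ifdef _WIN32
  // Macro fallback: the expression is evaluated where it is used.
  #define FP_MAX (std::numeric_limits<float>::max())
#else
// Preferred form: a single compile-time constant.
constexpr auto FP_MAX = std::numeric_limits<float>::max();
#endif

static_assert(FP_MAX > 0.0f, "sanity check");

int main() {}
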

+ 1 - 1
kernels/quantization/marlin/sparse/common/base.h

@@ -44,7 +44,7 @@ using I4 = Vec<int, 4>;
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
 using FragA = Vec<half2, 4>;
 using FragB = Vec<half2, 2>;
-using FragM = Vec<uint, 1>;
+using FragM = Vec<unsigned int, 1>;
 using FragC = Vec<float, 4>;
 using FragS = Vec<half2, 1>;  // quantization scales
 

+ 16 - 0
kernels/reduction.cuh

@@ -43,7 +43,23 @@ using ReduceFnType = T (*)(T, T);
 // Helper function to return the next largest power of 2
 static constexpr int _nextPow2(unsigned int num) {
   if (num <= 1) return num;
+
+#if defined(_MSC_VER) && !defined(__clang__) // MSVC without Clang
+  // Decrement n (to handle cases when n itself is a power of 2)
+  num--;
+  
+  // Set all bits after the first set bit
+  num |= num >> 1;
+  num |= num >> 2;
+  num |= num >> 4;
+  num |= num >> 8;
+  num |= num >> 16;
+  
+  // Add 1 to get the next power of 2
+  return num + 1;
+#else // GCC, Clang, or other compilers with __builtin_clz
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+#endif
 }
 
 template <typename T, int numLanes = WARP_SIZE>
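
The MSVC branch above computes the next power of two with the classic bit-smearing trick instead of __builtin_clz, which MSVC does not provide. A host-side copy of that fallback with a few compile-time checks (the function name here is illustrative):

// Mirrors the _MSC_VER branch of _nextPow2 above.
constexpr int next_pow2_portable(unsigned int num) {
  if (num <= 1) return num;
  num--;
  num |= num >> 1;
  num |= num >> 2;
  num |= num >> 4;
  num |= num >> 8;
  num |= num >> 16;
  return num + 1;
}

static_assert(next_pow2_portable(0) == 0);
static_assert(next_pow2_portable(1) == 1);
static_assert(next_pow2_portable(2) == 2);
static_assert(next_pow2_portable(3) == 4);
static_assert(next_pow2_portable(31) == 32);
static_assert(next_pow2_portable(33) == 64);
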

+ 1 - 1
requirements-cuda.txt

@@ -6,5 +6,5 @@ nvidia-ml-py == 12.555.43
 torch == 2.4.0
 torchvision == 0.19  # for phi3v
 xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
-triton >= 2.2.1
+triton >= 2.2.1; platform_system == 'Linux'
 aphrodite-flash-attn == 2.6.1.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0

+ 12 - 2
setup.py

@@ -52,7 +52,12 @@ if not sys.platform.startswith("linux"):
         "Aphrodite only supports Linux platform (including WSL). "
         f"Building on {sys.platform}, "
         "so APhrodite may not be able to run correctly")
-    APHRODITE_TARGET_DEVICE = "empty"
+    if sys.platform.startswith("win32"):
+        logger.warning("Only CUDA backend is tested on Windows.")
+        APHRODITE_TARGET_DEVICE = "cuda"
+    else:
+        APHRODITE_TARGET_DEVICE = "empty"
+
 
 MAIN_CUDA_VERSION = "12.4"
 
@@ -225,6 +230,9 @@ class cmake_build_ext(build_ext):
 def _no_device() -> bool:
     return APHRODITE_TARGET_DEVICE == "empty"
 
+def _is_windows() -> bool:
+    return sys.platform.startswith("win32")
+
 def _is_cuda() -> bool:
     has_cuda = torch.version.cuda is not None
     return (APHRODITE_TARGET_DEVICE == "cuda" and has_cuda
@@ -401,7 +409,7 @@ def get_requirements() -> List[str]:
                 resolved_requirements.append(line)
         return resolved_requirements
 
-    if _no_device():
+    if _no_device() or _is_windows():
         requirements = _read_requirements("requirements-cuda.txt")
     elif _is_cuda():
         requirements = _read_requirements("requirements-cuda.txt")
@@ -430,6 +438,8 @@ def get_requirements() -> List[str]:
         raise ValueError(
             "Unsupported platform, please use CUDA, ROCm, Neuron, CPU or "
             "OpenVINO.")
+    if _is_windows():
+        requirements.append("winloop")
     return requirements