formatting

AlpinDale committed 4 months ago
Current commit: 064c701d18

+ 3 - 3
aphrodite/common/utils.py

@@ -4,6 +4,7 @@ import contextlib
 import datetime
 import enum
 import gc
+import math
 import os
 import socket
 import subprocess
@@ -12,13 +13,12 @@ import tempfile
 import threading
 import uuid
 import warnings
-import math
 from asyncio import FIRST_COMPLETED, ensure_future
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
-                    Type, TypeVar, Union, overload, Iterable)
+                    Hashable, Iterable, List, Literal, Optional, OrderedDict,
+                    Set, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4
 
 import numpy as np

+ 1 - 1
aphrodite/distributed/parallel_state.py

@@ -21,9 +21,9 @@ If you only need to use the distributed environment without model/pipeline
  steps.
 """
 import contextlib
-import sys
 import os
 import pickle
+import sys
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass

+ 2 - 2
aphrodite/endpoints/openai/api_server.py

@@ -23,7 +23,8 @@ from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
 from aphrodite.common.utils import (FlexibleArgumentParser,
-                                    get_open_zmq_ipc_path, random_uuid)
+                                    get_open_zmq_ipc_path, in_windows,
+                                    random_uuid)
 from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.args import make_arg_parser
 # yapf: disable
@@ -54,7 +55,6 @@ from aphrodite.engine.protocol import AsyncEngineClient
 from aphrodite.server import serve_http
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
 from aphrodite.version import __version__ as APHRODITE_VERSION
-from aphrodite.common.utils import in_windows
 
 if in_windows():
     import winloop as uvloop

+ 1 - 1
aphrodite/endpoints/openai/rpc/server.py

@@ -9,12 +9,12 @@ from loguru import logger
 from typing_extensions import Never
 
 from aphrodite import AsyncAphrodite, AsyncEngineArgs
+from aphrodite.common.utils import in_windows
 from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
                                             APHRODITE_RPC_SUCCESS_STR,
                                             RPCAbortRequest,
                                             RPCGenerateRequest,
                                             RPCUtilityRequest)
-from aphrodite.common.utils import in_windows
 
 if in_windows():
     import winloop as uvloop

+ 1 - 1
aphrodite/server/launch.py

@@ -8,9 +8,9 @@ import uvicorn
 from fastapi import FastAPI, Response
 from loguru import logger
 
+from aphrodite.common.utils import in_windows
 from aphrodite.engine.async_aphrodite import AsyncEngineDeadError
 from aphrodite.engine.protocol import AsyncEngineClient
-from aphrodite.common.utils import in_windows
 
 APHRODITE_KEEP_ALIVE_ON_ENGINE_DEATH = bool(os.getenv(
     "APHRODITE_KEEP_ALIVE_ON_ENGINE_DEATH", 0))

+ 4 - 5
kernels/core/registration.h

@@ -1,14 +1,13 @@
 #pragma once
 
 #ifdef _DEBUG
-#undef _DEBUG
-#include <Python.h>
-#define _DEBUG
+  #undef _DEBUG
+  #include <Python.h>
+  #define _DEBUG
 #else
-#include <Python.h>
+  #include <Python.h>
 #endif
 
-
 #define _CONCAT(A, B) A##B
 #define CONCAT(A, B) _CONCAT(A, B)
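
Side note on the _CONCAT/CONCAT pair kept above: the two-level indirection lets macro arguments expand before ## is applied. A standalone sketch with hypothetical names (MY_CONCAT, MODULE_NAME), not identifiers from this repository:

    // Mirrors the _CONCAT/CONCAT pattern above with local names.
    #define MY_CONCAT_(A, B) A##B
    #define MY_CONCAT(A, B) MY_CONCAT_(A, B)

    #define MODULE_NAME my_ops
    // Expands to `int my_ops_counter = 0;` because MODULE_NAME is expanded
    // before the paste; a single-level A##B would instead yield the token
    // MODULE_NAME_counter.
    int MY_CONCAT(MODULE_NAME, _counter) = 0;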
 

+ 6 - 7
kernels/mamba/causal_conv1d/causal_conv1d.cu

@@ -589,10 +589,10 @@ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel(
                  batch_id * params.out_batch_stride +
                  (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride +
                  chunk_c_id * kChunkSizeC + c_idx * kNElts;
-  [[maybe_unused]] int* seq_idx = !kHasSeqIdx
-                     ? nullptr
-                     : reinterpret_cast<int*>(params.seq_idx_ptr) +
-                           batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
+  [[maybe_unused]] int* seq_idx =
+      !kHasSeqIdx ? nullptr
+                  : reinterpret_cast<int*>(params.seq_idx_ptr) +
+                        batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
   input_t* initial_states =
       params.initial_states_ptr == nullptr || chunk_l_id > 0
           ? nullptr
@@ -702,9 +702,8 @@ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel(
 #pragma unroll
   for (int i = 0; i < kLPerThread; ++i) {
     out_vals[i] = bias_val;
-    [[maybe_unused]] const int seq_idx_cur = !kHasSeqIdx
-                                              ? 0
-                                              : seq_idx_thread[i + kWidth - 1];
+    [[maybe_unused]] const int seq_idx_cur =
+        !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
 #pragma unroll
     for (int w = 0; w < kWidth; ++w) {
       if constexpr (!kHasSeqIdx) {

+ 6 - 5
kernels/mamba/mamba_ssm/selective_scan_fwd.cu

@@ -146,21 +146,22 @@ __global__ __launch_bounds__(
   weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr) +
                 dim_id * kNRows * params.A_d_stride;
   [[maybe_unused]] weight_t* B = reinterpret_cast<weight_t*>(params.B_ptr) +
-                dim_id * kNRows * params.B_d_stride;
+                                 dim_id * kNRows * params.B_d_stride;
   input_t* Bvar = reinterpret_cast<input_t*>(params.B_ptr) +
                   batch_id * params.B_batch_stride +
                   group_id * params.B_group_stride;
   [[maybe_unused]] weight_t* C = reinterpret_cast<weight_t*>(params.C_ptr) +
-                dim_id * kNRows * params.C_d_stride;
+                                 dim_id * kNRows * params.C_d_stride;
   input_t* Cvar = reinterpret_cast<input_t*>(params.C_ptr) +
                   batch_id * params.C_batch_stride +
                   group_id * params.C_group_stride;
   scan_t* x = reinterpret_cast<scan_t*>(params.x_ptr) +
               (batch_id * params.dim + dim_id * kNRows) * params.n_chunks *
                   params.dstate;
-  [[maybe_unused]] int* index = !kUseIndex ? nullptr
-                          : reinterpret_cast<int*>(params.index_ptr) +
-                                batch_id * params.seqlen;
+  [[maybe_unused]] int* index =
+      !kUseIndex
+          ? nullptr
+          : reinterpret_cast<int*>(params.index_ptr) + batch_id * params.seqlen;
 
   float D_val[kNRows] = {0};
   if (params.D_ptr != nullptr) {

+ 9 - 9
kernels/mamba/mamba_ssm/static_switch.h

@@ -14,13 +14,13 @@
 ///     some_function<BoolConst>(...);
 /// });
 /// ```
-#define BOOL_SWITCH(COND, CONST_NAME, ...) \
-  [&] {                                    \
-    if (COND) {                            \
-      static constexpr bool CONST_NAME = true;    \
-      return __VA_ARGS__();                \
-    } else {                               \
-      static constexpr bool CONST_NAME = false;   \
-      return __VA_ARGS__();                \
-    }                                      \
+#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
+  [&] {                                         \
+    if (COND) {                                 \
+      static constexpr bool CONST_NAME = true;  \
+      return __VA_ARGS__();                     \
+    } else {                                    \
+      static constexpr bool CONST_NAME = false; \
+      return __VA_ARGS__();                     \
+    }                                           \
   }()
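
For context, a minimal usage sketch of this BOOL_SWITCH-style dispatch, assuming the macro above is in scope; run_kernel and launch are hypothetical names, not functions in this repository:

    // Hypothetical templated entry point specialized on a compile-time flag.
    template <bool kUseBias>
    void run_kernel(const float* /*x*/, int /*n*/) {
      // A real launcher would dispatch a CUDA kernel instantiated for kUseBias.
    }

    void launch(const float* x, int n, bool use_bias) {
      // The macro expands to an immediately-invoked lambda; each branch
      // declares a distinct `static constexpr bool`, so the runtime flag
      // selects a compile-time template instantiation.
      BOOL_SWITCH(use_bias, kUseBias, [&] { run_kernel<kUseBias>(x, n); });
    }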

+ 1 - 1
kernels/moe/torch_bindings.cpp

@@ -1,5 +1,5 @@
 #ifdef _WIN32
-#include <crtdefs.h>
+  #include <crtdefs.h>
 #endif
 
 #include "../core/registration.h"

+ 20 - 40
kernels/quantization/awq/gemm_kernels.cu

@@ -176,16 +176,14 @@ __global__ void __launch_bounds__(64)
     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
       {
         unsigned int addr;
-        asm(
-            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+        asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
             "addr; }\n"
             : "=r"(addr)
             : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
                           (((((int)threadIdx.x) & 15) * 40) +
                            ((((int)threadIdx.x) >> 4) * 8)))));
 
-        asm(
-            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+        asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
             "{%0, %1, %2, %3}, [%4];\n"
             : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
               "=r"(((unsigned*)(A_shared_warp + 0))[1]),
@@ -197,8 +195,7 @@ __global__ void __launch_bounds__(64)
       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
         {
           unsigned int addr;
-          asm(
-              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+          asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
               "addr; }\n"
               : "=r"(addr)
               : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
@@ -206,8 +203,7 @@ __global__ void __launch_bounds__(64)
                                          (ax1_0 * 16))])) +
                             (((((int)threadIdx.x) & 15) * (N + 8)) +
                              ((((int)threadIdx.x) >> 4) * 8)))));
-          asm(
-              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+          asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
               "{%0, %1, %2, %3}, [%4];\n"
               : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
                 "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
@@ -219,8 +215,7 @@ __global__ void __launch_bounds__(64)
       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
@@ -236,8 +231,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
@@ -253,8 +247,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
@@ -270,8 +263,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
@@ -287,8 +279,7 @@ __global__ void __launch_bounds__(64)
         }
   #else
         {
-          asm(
-              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -308,8 +299,7 @@ __global__ void __launch_bounds__(64)
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
@@ -558,16 +548,14 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
       {
         unsigned int addr;
-        asm(
-            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+        asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
             "addr; }\n"
             : "=r"(addr)
             : "l"((void*)((&(A_shared[(k_0_1 * 16)])) +
                           (((((int)threadIdx.x) & 15) * 40) +
                            ((((int)threadIdx.x) >> 4) * 8)))));
 
-        asm(
-            "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+        asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
             "{%0, %1, %2, %3}, [%4];\n"
             : "=r"(((unsigned*)(A_shared_warp + 0))[0]),
               "=r"(((unsigned*)(A_shared_warp + 0))[1]),
@@ -579,8 +567,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
         {
           unsigned int addr;
-          asm(
-              "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
+          asm("{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
               "addr; }\n"
               : "=r"(addr)
               : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) +
@@ -588,8 +575,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
                                          (ax1_0 * 16))])) +
                             (((((int)threadIdx.x) & 15) * (N + 8)) +
                              ((((int)threadIdx.x) >> 4) * 8)))));
-          asm(
-              "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+          asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
               "{%0, %1, %2, %3}, [%4];\n"
               : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]),
                 "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]),
@@ -601,8 +587,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
       for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
@@ -618,8 +603,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
@@ -635,8 +619,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
                 "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]),
@@ -652,8 +635,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),
                 "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]),
@@ -669,8 +651,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
   #else
         {
-          asm(
-              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
               : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]),
@@ -690,8 +671,7 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
         }
 
         {
-          asm(
-              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+          asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
               "%13};\n"
               : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]),

+ 6 - 5
kernels/quantization/fp8/common.cu

@@ -10,11 +10,12 @@
 
 #ifndef USE_ROCM
 using FP8_TYPE = c10::Float8_e4m3fn;
-#ifdef _WIN32
-#define FP8_E4M3_MAX (std::numeric_limits<FP8_TYPE>::max())
-#else
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#endif
+  #ifdef _WIN32
+    #define FP8_E4M3_MAX (std::numeric_limits<FP8_TYPE>::max())
+  #else
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+  #endif
 #else
   #include "amd/hip_float8.h"
 using FP8_TYPE = c10::Float8_e4m3fnuz;

+ 4 - 4
kernels/quantization/gptq_marlin/gptq_marlin.cu

@@ -1697,10 +1697,10 @@ __global__ void Marlin(
   #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
                     HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS)          \
     if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&              \
-             thread_n_blocks == THREAD_N_BLOCKS &&                             \
-             thread_k_blocks == THREAD_K_BLOCKS &&                             \
-             has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&             \
-             group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {     \
+        thread_n_blocks == THREAD_N_BLOCKS &&                                  \
+        thread_k_blocks == THREAD_K_BLOCKS &&                                  \
+        has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&                  \
+        group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {          \
       cudaFuncSetAttribute(                                                    \
           Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,          \
                  THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \

+ 2 - 2
kernels/quantization/quant_ops.h

@@ -109,7 +109,7 @@ at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
 void decompress_e8p_origorder(torch::Tensor YIs, torch::Tensor CB,
                               torch::Tensor& Y);
 
-#ifndef _WIN32
+  #ifndef _WIN32
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
@@ -132,7 +132,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                               torch::Tensor const& s_group,
                               torch::Tensor& workspace, int64_t size_m,
                               int64_t size_n, int64_t size_k);
-#endif
+  #endif
 
 torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
                                           torch::Tensor _in_feats,

+ 4 - 4
kernels/reduction.cuh

@@ -44,20 +44,20 @@ using ReduceFnType = T (*)(T, T);
 static constexpr int _nextPow2(unsigned int num) {
   if (num <= 1) return num;
 
-#if defined(_MSC_VER) && !defined(__clang__) // MSVC without Clang
+#if defined(_MSC_VER) && !defined(__clang__)  // MSVC without Clang
   // Decrement n (to handle cases when n itself is a power of 2)
   num--;
-  
+
   // Set all bits after the first set bit
   num |= num >> 1;
   num |= num >> 2;
   num |= num >> 4;
   num |= num >> 8;
   num |= num >> 16;
-  
+
   // Add 1 to get the next power of 2
   return num + 1;
-#else // GCC, Clang, or other compilers with __builtin_clz
+#else  // GCC, Clang, or other compilers with __builtin_clz
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 #endif
 }
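
Both branches of _nextPow2 above round up to the nearest power of two; a standalone spot-check of the intended behavior (a sketch under that assumption, not test code from this repository):

    #include <cassert>

    // Reference behavior of _nextPow2: smallest power of two >= num, for num >= 1.
    constexpr unsigned next_pow2_ref(unsigned num) {
      unsigned p = 1;
      while (p < num) p <<= 1;
      return p;
    }

    int main() {
      // e.g. 1 -> 1, 5 -> 8, 8 -> 8, 1000 -> 1024
      assert(next_pow2_ref(1) == 1);
      assert(next_pow2_ref(5) == 8);
      assert(next_pow2_ref(8) == 8);
      assert(next_pow2_ref(1000) == 1024);
      return 0;
    }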

+ 2 - 2
kernels/torch_bindings.cpp

@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
 
-#ifndef _WIN32
+  #ifndef _WIN32
   // marlin_qqq_gemm for QQQ.
   ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
   ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
@@ -189,7 +189,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "                  Tensor b_scales, Tensor azp_adj,"
       "                  Tensor? azp, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
-#endif
+  #endif
 
   // QuIP# GEMV
   ops.def("quip_gemv", &e8p_mm_origorder);

+ 1 - 1
setup.py

@@ -56,7 +56,7 @@ if not sys.platform.startswith("linux"):
         logger.warning("Only CUDA backend is tested on Windows.")
         APHRODITE_TARGET_DEVICE = "cuda"
     else:
-        APHRODITE_TARGET_DEVICE = empty
+        APHRODITE_TARGET_DEVICE = "empty"
        
 
 MAIN_CUDA_VERSION = "12.4"