Răsfoiți Sursa

fix: ROCm build (#817)

* Some fixes (ig)

* Oopsie :3

* Remove a comment

* indent the block

* another indentation

* Revert stuff (hopefully)

---------

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: AlpinDale <alpindale@gmail.com>
Naomiusearch 3 luni în urmă
părinte
comite
4f9fea4c4d
2 a modificat fișierele cu 15 adăugiri și 17 ștergeri
  1. 14 16
      CMakeLists.txt
  2. 1 1
      amdpatch.sh

+ 14 - 16
CMakeLists.txt

@@ -20,7 +20,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
 set(CUDA_SUPPORTED_ARCHS "6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0")
 
 # Supported AMD GPU architectures.
-set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
 
 #
 # Supported/expected torch versions for CUDA/ROCm.
@@ -65,20 +65,19 @@ endif()
 # etc.
 #
 find_package(Torch REQUIRED)
-find_package(CUDA REQUIRED)
-find_package(CUDAToolkit REQUIRED)
-
-# Add cuBLAS to the list of libraries to link against
-list(APPEND LIBS CUDA::cublas)
-
-set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_STANDARD_REQUIRED ON)
-set(CMAKE_CUDA_STANDARD 17)
-set(CMAKE_CUDA_STANDARD_REQUIRED ON)
-
-# Replace -std=c++20 with -std=c++17 in APHRODITE_GPU_FLAGS
-if(APHRODITE_GPU_LANG STREQUAL "CUDA")
-  list(APPEND APHRODITE_GPU_FLAGS "--std=c++17" "-Xcompiler -Wno-return-type")
+if(MSVC)
+  find_package(CUDA REQUIRED)
+  find_package(CUDAToolkit REQUIRED)
+  # Add cuBLAS to the list of libraries to link against
+  list(APPEND LIBS CUDA::cublas)
+  set(CMAKE_CXX_STANDARD 17)
+  set(CMAKE_CXX_STANDARD_REQUIRED ON)
+  set(CMAKE_CUDA_STANDARD 17)
+  set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+  # For CUDA builds, append --std=c++17 and -Xcompiler -Wno-return-type to APHRODITE_GPU_FLAGS (existing flags are not removed)
+  if(APHRODITE_GPU_LANG STREQUAL "CUDA")
+    list(APPEND APHRODITE_GPU_FLAGS "--std=c++17" "-Xcompiler -Wno-return-type")
+  endif()
 endif()
 
 #
@@ -222,7 +221,6 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
     "kernels/permute_cols.cu"
     "kernels/sampling/sampling.cu")
 
-  # Add CUTLASS and GPTQ Marlin kernels if not MSVC
   if(NOT MSVC)
     # Include CUTLASS only when needed
     include(FetchContent)

+ 1 - 1
amdpatch.sh

@@ -2,4 +2,4 @@
 
 ROCM_PATH=$(hipconfig --rocmpath)
 
-sudo patch $ROCM_PATH/lib/llvm/lib/clang/18/include/__clang_hip_cmath.h ./patches/amd.patch
+sudo patch $ROCM_PATH/lib/llvm/lib/clang/*/include/__clang_hip_cmath.h ./patches/amd.patch