
fix: make AMD usable (#775)

* Patch for __test ambiguity on rocm

* Make amd work

* Update docs

* Update docs again

* Don't default to marlin for amd
Naomiusearch committed 4 months ago
Parent commit: eee3cf5dab

+ 1 - 1
CMakeLists.txt

@@ -268,4 +268,4 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
 
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
-endif()
+endif()

+ 5 - 0
amdpatch.sh

@@ -0,0 +1,5 @@
+#!/bin/sh
+
+ROCM_PATH=$(hipconfig --rocmpath)
+
+sudo patch $ROCM_PATH/lib/llvm/lib/clang/18/include/__clang_hip_cmath.h ./patches/amd.patch

+ 5 - 0
aphrodite/quantization/gptq_marlin.py

@@ -19,6 +19,7 @@ from aphrodite.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
     verify_marlin_supported, verify_marlin_supports_shape)
 from aphrodite.scalar_type import scalar_types
+from aphrodite.common.utils import is_hip
 
 
 class GPTQMarlinConfig(QuantizationConfig):
@@ -93,6 +94,9 @@ class GPTQMarlinConfig(QuantizationConfig):
         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                                or user_quant == "gptq_marlin")
 
+        if is_hip():
+            return None
+
         if can_convert and is_valid_user_quant:
             msg = ("The model is convertible to {} during runtime."
                    " Using {} kernel.".format(cls.get_name(), cls.get_name()))
@@ -105,6 +109,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                         " so forcing gptq. Use quantization=gptq_marlin for"
                         " faster inference")
         return None
+            
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["GPTQMarlinLinearMethod"]:

+ 2 - 0
docs/pages/installation/installation-rocm.md

@@ -72,6 +72,8 @@ Finally, build Aphrodite:
 git clone https://github.com/PygmalionAI/aphrodite-engine.git
 cd aphrodite-engine
 
+chmod +x ./amdpatch.sh
+./amdpatch.sh 
 pip install -U -r requirements-rocm.txt
 python setup.py develop  #  pip install -e . won't work for now
 ```

+ 20 - 23
kernels/ops.h

@@ -62,29 +62,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor expert_ids,
                           torch::Tensor num_tokens_post_pad);
 
-std::vector<torch::Tensor> selective_scan_fwd(
-    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
-    const torch::Tensor& B, const torch::Tensor& C,
-    const c10::optional<torch::Tensor>& D_,
-    const c10::optional<torch::Tensor>& z_,
-    const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
-    const c10::optional<torch::Tensor>& index_,
-    const c10::optional<torch::Tensor>& x);
-
-at::Tensor causal_conv1d_update(const at::Tensor& x,
-                                const at::Tensor& conv_state,
-                                const at::Tensor& weight,
-                                const c10::optional<at::Tensor>& bias_,
-                                bool silu_activation);
-
-at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
-                             const c10::optional<at::Tensor>& bias_,
-                             const c10::optional<at::Tensor>& seq_idx_,
-                             const c10::optional<at::Tensor>& seq_pos_idx_,
-                             const c10::optional<at::Tensor>& initial_states_,
-                             const c10::optional<at::Tensor>& final_states_out_,
-                             bool silu_activation);
-
 #ifndef USE_ROCM
 using fptr_t = int64_t;
 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
@@ -105,4 +82,24 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
+std::vector<torch::Tensor> selective_scan_fwd(
+    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
+    const torch::Tensor& B, const torch::Tensor& C,
+    const c10::optional<torch::Tensor>& D_,
+    const c10::optional<torch::Tensor>& z_,
+    const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
+    const c10::optional<torch::Tensor>& index_,
+    const c10::optional<torch::Tensor>& x);
+at::Tensor causal_conv1d_update(const at::Tensor& x,
+                                const at::Tensor& conv_state,
+                                const at::Tensor& weight,
+                                const c10::optional<at::Tensor>& bias_,
+                                bool silu_activation);
+at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
+                             const c10::optional<at::Tensor>& bias_,
+                             const c10::optional<at::Tensor>& seq_idx_,
+                             const c10::optional<at::Tensor>& seq_pos_idx_,
+                             const c10::optional<at::Tensor>& initial_states_,
+                             const c10::optional<at::Tensor>& final_states_out_,
+                             bool silu_activation);
 #endif

+ 2 - 1
kernels/torch_bindings.cpp

@@ -271,7 +271,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
-
+#ifndef USE_ROCM
   // Mamba kernels
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
@@ -298,6 +298,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? final_states_out_,"
       "bool silu_activation) -> Tensor");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+#endif
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
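
The Mamba and causal-conv1d declarations in kernels/ops.h and their schema/impl registrations here move under the same #ifndef USE_ROCM guard, so a ROCm build never defines a schema for, or takes the address of, kernels it does not compile. A rough, self-contained sketch of that pattern, with a hypothetical registry and kernel name rather than the real torch bindings:

```cpp
#include <cstdio>
#include <map>
#include <string>

#ifndef USE_ROCM
// Hypothetical CUDA-only kernel: on a ROCm build this symbol is never emitted.
void causal_conv1d_fwd_sketch() { std::puts("CUDA-only kernel"); }
#endif

// Toy stand-in for the op registry that the TORCH_LIBRARY block populates.
std::map<std::string, void (*)()>& op_registry() {
  static std::map<std::string, void (*)()> registry;
  return registry;
}

void register_ops() {
#ifndef USE_ROCM
  // Registration sits under the same guard as the declaration above;
  // guarding only one of the two would leave a reference to a kernel that
  // does not exist in the ROCm build.
  op_registry()["causal_conv1d_fwd"] = &causal_conv1d_fwd_sketch;
#endif
}

int main() {
  register_ops();
  std::printf("registered ops: %zu\n", op_registry().size());
  return 0;
}
```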

+ 17 - 0
patches/amd.patch

@@ -0,0 +1,17 @@
+diff --git a/clang/lib/Headers/__clang_hip_cmath.h b/clang/lib/Headers/__clang_hip_cmath.h
+index 071c64c7af8d5b..e04fc7824b1771 100644
+--- a/clang/lib/Headers/__clang_hip_cmath.h
++++ b/clang/lib/Headers/__clang_hip_cmath.h
+@@ -397,7 +397,12 @@ template <class _Tp> struct __numeric_type {
+   // No support for long double, use double instead.
+   static double __test(long double);
+ 
+-  typedef decltype(__test(declval<_Tp>())) type;
++  template <typename _U>
++  static auto __test_impl(int) -> decltype(__test(declval<_U>()));
++
++  template <typename _U> static void __test_impl(...);
++
++  typedef decltype(__test_impl<_Tp>(0)) type;
+   static const bool value = !is_same<type, void>::value;
+ };
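
The patch above replaces the direct decltype(__test(declval<_Tp>())) with a SFINAE wrapper: if the call to __test is ill-formed for some _Tp (the "__test ambiguity on rocm" the commit message refers to), overload resolution now falls back to void instead of producing a hard compile error. A minimal, self-contained sketch of that technique, using an illustrative Ambiguous type rather than the actual ROCm header:

```cpp
// Sketch of the SFINAE fallback applied in patches/amd.patch (illustrative
// names mirroring the patched header; not the real __clang_hip_cmath.h).
#include <type_traits>
#include <utility>

// A type with two equally ranked implicit conversions, so an unguarded
// __test(declval<Ambiguous>()) call would be ambiguous.
struct Ambiguous {
  operator float() const;
  operator double() const;
};

template <class _Tp> struct numeric_type_sketch {
  static float __test(float);
  static double __test(double);

  // If __test(declval<_U>()) is ill-formed (e.g. the overload set is
  // ambiguous), this candidate is silently discarded...
  template <typename _U>
  static auto __test_impl(int) -> decltype(__test(std::declval<_U>()));

  // ...and the variadic fallback is chosen instead, yielding void.
  template <typename _U> static void __test_impl(...);

  typedef decltype(__test_impl<_Tp>(0)) type;
  static const bool value = !std::is_same<type, void>::value;
};

static_assert(std::is_same<numeric_type_sketch<float>::type, float>::value,
              "an unambiguous call still resolves as before");
static_assert(!numeric_type_sketch<Ambiguous>::value,
              "an ambiguous call now degrades to void instead of erroring");

int main() { return 0; }
```

The trade-off is that a type whose __test call is ambiguous is now classified as non-numeric (value == false) instead of breaking every instantiation of the trait, which is what unblocks the ROCm build.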

+ 1 - 1
requirements-rocm.txt

@@ -5,5 +5,5 @@
 awscli
 boto3
 botocore
-ray == 2.10.0
+ray >= 2.10.0
 peft