Pārlūkot izejas kodu

Add in, macrosf for defining __grid_constant__ (#852)

Driss Guessous 1 gadu atpakaļ
vecāks
revīzija
4a73e903da

+ 1 - 1
csrc/flash_attn/flash_api.cpp

@@ -46,7 +46,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       bool seqlenq_ngroups_swapped=false) {
 
     // Reset the parameters
-    memset(&params, 0, sizeof(params));
+    params = {};
 
     params.is_bf16 = q.dtype() == torch::kBFloat16;
 

+ 34 - 11
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -11,6 +11,40 @@
 #include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"
 
+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
+template<typename Kernel_traits, __VA_ARGS__> \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
+
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+       flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+
 template<bool Clear_dQaccum=true, typename Kernel_traits>
 __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
     flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
@@ -21,17 +55,6 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
     flash::clear_dKVaccum<Kernel_traits>(params);
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
-    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
-}
-
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
-    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
-}
-
 template<typename Kernel_traits>
 __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
     flash::convert_dQ<Kernel_traits>(params, nsplits);

+ 30 - 9
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -10,19 +10,40 @@
 #include "flash.h"
 #include "flash_fwd_kernel.h"
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
-__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
-    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
+template<typename Kernel_traits, __VA_ARGS__> \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        static_assert(!(Is_causal && Is_local)); // Enforce constraints
+        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
 }
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
-__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
 }
 
-template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
-__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
     static_assert(Log_max_splits >= 1);
     flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }