
Simplify BOOL_SWITCH macro to fix compile error on gcc 7

Tri Dao, 2 years ago
parent commit 8a2ece89f7
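For context on the fix: before this commit, BOOL_SWITCH and FP16_SWITCH expanded their body into an immediately-invoked lambda, so nesting two switches produced a lambda inside a lambda, each declaring a constexpr local. gcc 7 is known to mishandle that shape (see the kokkos-kernels and flash-attention issues linked in fmha_fwd_launch_template.h below). A minimal standalone sketch of the failing pattern, not taken from the commit; OLD_BOOL_SWITCH and run are illustrative names:

// Sketch of the pre-commit macro and the nested use that trips gcc 7.
#include <cstdio>

#define OLD_BOOL_SWITCH(COND, CONST_NAME, ...)  \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

template <bool A, bool B>
void run() { std::printf("%d %d\n", A, B); }

int main() {
    bool x = true, y = false;
    // After expansion this is a lambda nested in a lambda, the shape the
    // commit removes by turning the switches into plain if/else statements.
    OLD_BOOL_SWITCH(x, AConst, [&] {
        OLD_BOOL_SWITCH(y, BConst, [&] {
            run<AConst, BConst>();
        });
    });
    return 0;
}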

+ 2 - 1
csrc/flash_attn/src/fmha.h

@@ -36,7 +36,8 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #endif
 
-#include <ATen/cuda/CUDAGraphsUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/UnpackRaw.cuh>
 
 #include <fmha_utils.h>
 

+ 2 - 3
csrc/flash_attn/src/fmha_bwd_hdim128.cu

@@ -5,9 +5,8 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
+    FP16_SWITCH(params.is_bf16, ({
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
         run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-    });
+    }));
 }
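For reference, the rewritten call above expands, after preprocessing, to a plain if/else with no lambda involved, roughly:

// Approximate expansion of the new FP16_SWITCH call in run_fmha_bwd_hdim128.
if (params.is_bf16) {
    using elem_type = __nv_bfloat16;
    ({
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
    });
} else {
    using elem_type = __half;
    ({
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
    });
}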

+ 2 - 3
csrc/flash_attn/src/fmha_bwd_hdim32.cu

@@ -5,8 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
+    FP16_SWITCH(params.is_bf16, ({
         if (params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
@@ -14,5 +13,5 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const b
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
         }
-    });
+    }));
 }

+ 2 - 3
csrc/flash_attn/src/fmha_bwd_hdim64.cu

@@ -5,8 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
+    FP16_SWITCH(params.is_bf16, ({
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
@@ -27,5 +26,5 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
                 run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
             }
         }
-    });
+    }));
 }

+ 2 - 3
csrc/flash_attn/src/fmha_bwd_launch_template.h

@@ -3,7 +3,6 @@
 #pragma once
 
 #include "static_switch.h"
-#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"
 
@@ -62,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
 
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
-    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+    BOOL_SWITCH(is_dropout, IsDropoutConst, ({
         auto kernel = params.is_causal
             ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
             : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
@@ -111,5 +110,5 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
             kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
         }
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-    });
+    }));
 }
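With the statement-based macro, the nesting that the work-around comment above refers to expands to nested if/else blocks rather than nested lambdas. A sketch of how such a nested dispatch would look under the new form; the restructured causal dispatch here is hypothetical (the actual code above keeps the ternary):

BOOL_SWITCH(is_dropout, IsDropoutConst, ({
    BOOL_SWITCH(params.is_causal, IsCausalConst, ({
        auto kernel = &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, IsCausalConst>;
        kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
    }));
}));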

+ 2 - 2
csrc/flash_attn/src/fmha_fwd_hdim128.cu

@@ -5,8 +5,8 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+    FP16_SWITCH(launch_params.params.is_bf16, ({
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fwd_loop<Kernel_traits>(launch_params);
-    });
+    }));
 }

+ 2 - 2
csrc/flash_attn/src/fmha_fwd_hdim32.cu

@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+    FP16_SWITCH(launch_params.params.is_bf16, ({
         if (launch_params.params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
@@ -13,5 +13,5 @@ void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
         }
-    });
+    }));
 }

+ 2 - 2
csrc/flash_attn/src/fmha_fwd_hdim64.cu

@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+    FP16_SWITCH(launch_params.params.is_bf16, ({
         if (launch_params.params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
@@ -13,5 +13,5 @@ void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);
         }
-    });
+    }));
 }

+ 2 - 3
csrc/flash_attn/src/fmha_fwd_launch_template.h

@@ -8,7 +8,6 @@
 #include <cuda_bf16.h>
 
 #include "static_switch.h"
-#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 
@@ -57,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     // https://github.com/kokkos/kokkos-kernels/issues/349
     // https://github.com/HazyResearch/flash-attention/issues/21
-    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
+    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ({
         auto kernel = launch_params.params.is_causal
             ? (launch_params.return_softmax
                ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
@@ -88,5 +87,5 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
         kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
             launch_params.params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
-    });
+    }));
 }

+ 0 - 27
csrc/flash_attn/src/fp16_switch.h

@@ -1,27 +0,0 @@
-// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
-// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
-
-// modified from static_switch.h 
-// because MSVC cannot handle std::conditional with constexpr variable
-
-#pragma once
-
-/// @param COND       - a boolean expression to switch by
-/// @param ...       - code to execute for true and false
-///
-/// Usage:
-/// ```
-/// FP16_SWITCH(flag, [&] {
-///     some_function(...);
-/// });
-/// ```
-#define FP16_SWITCH(COND, ...)                                           \
-    [&] {                                                                            \
-        if (COND) {                                                                  \
-            using elem_type = __nv_bfloat16;   \
-            return __VA_ARGS__();                                                    \
-        } else {                                                                     \
-            using elem_type = __half;   \
-            return __VA_ARGS__();                                                    \
-        }                                                                            \
-    }()
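The deleted header defined the expression form of FP16_SWITCH: an immediately-invoked lambda that could, in principle, return a value. Its statement-based replacement in static_switch.h (below) cannot, which every call site touched by this commit satisfies. Side by side, with run_kernel standing in for the real launch functions:

// Old (deleted) form: the body is a lambda and the macro is an expression.
FP16_SWITCH(params.is_bf16, [&] {
    run_kernel<elem_type>(params);
});

// New form: the body is a parenthesized compound statement and the macro is
// a bare if/else, so it can only be used in statement position.
FP16_SWITCH(params.is_bf16, ({
    run_kernel<elem_type>(params);
}));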

+ 22 - 12
csrc/flash_attn/src/static_switch.h

@@ -9,17 +9,27 @@
 ///
 /// Usage:
 /// ```
-/// BOOL_SWITCH(flag, BoolConst, [&] {
+/// BOOL_SWITCH(flag, BoolConst, ({
 ///     some_function<BoolConst>(...);
-/// });
+/// }));
 /// ```
-#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
-    [&] {                                                                            \
-        if (COND) {                                                                  \
-            constexpr bool CONST_NAME = true;                                        \
-            return __VA_ARGS__();                                                    \
-        } else {                                                                     \
-            constexpr bool CONST_NAME = false;                                       \
-            return __VA_ARGS__();                                                    \
-        }                                                                            \
-    }()
+/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
+#define BOOL_SWITCH(COND, CONST_NAME, CODE) \
+    if (COND) {                             \
+        constexpr bool CONST_NAME = true;   \
+        CODE;                               \
+    } else {                                \
+        constexpr bool CONST_NAME = false;  \
+        CODE;                               \
+    }
+
+// modified from BOOL_SWITCH
+// because MSVC cannot handle std::conditional with constexpr variable
+#define FP16_SWITCH(COND, CODE)        \
+    if (COND) {                        \
+      using elem_type = __nv_bfloat16; \
+      CODE;                            \
+    } else {                           \
+      using elem_type = __half;        \
+      CODE;                            \
+    }
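Why the bodies must be wrapped in "({" and "})": the preprocessor splits macro arguments at commas that are not protected by parentheses, and braces alone do not protect them. Because the new macros take the body as a single named CODE parameter instead of absorbing it with __VA_ARGS__, an unwrapped body containing a template argument list would be parsed as several macro arguments. A sketch, reusing the hdim128 call site from above:

// The "( )" protects the commas in the template argument list below from the
// preprocessor; the "{ }" keeps the body a single compound statement. In the
// expansion, "({ ... });" is accepted by gcc and nvcc as a GNU statement
// expression in statement position.
FP16_SWITCH(params.is_bf16, ({
    using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
    run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
}));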