|
@@ -10,19 +10,40 @@
|
|
|
#include "flash.h"
|
|
|
#include "flash_fwd_kernel.h"
|
|
|
|
|
|
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
|
|
-__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
|
|
|
- static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
|
|
- flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
|
|
+// When __CUDA_ARCH__ is defined and >= 800: define ARCH_SUPPORTS_FLASH and let KERNEL_PARAM_MODIFIER expand to __grid_constant__; otherwise the modifier expands to nothing
|
|
|
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
|
+#define ARCH_SUPPORTS_FLASH
|
|
|
+#define KERNEL_PARAM_MODIFIER __grid_constant__
|
|
|
+#else  // __CUDA_ARCH__ undefined or < 800: ARCH_SUPPORTS_FLASH is left undefined
|
|
|
+#define KERNEL_PARAM_MODIFIER
|
|
|
+#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
|
|
+
|
|
|
+// Centralized unsupported-arch handler: device printf runs per thread, so only thread (0,0,0) of block (0,0,0) reports, on its own line. The macro keeps its trailing ';' because call sites invoke it without one.
|
|
|
+#define FLASH_UNSUPPORTED_ARCH do { if ((threadIdx.x | threadIdx.y | threadIdx.z | blockIdx.x | blockIdx.y | blockIdx.z) == 0) { printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!\n"); } } while (0);
|
|
|
+
|
|
|
+// Emits a kernel header: template<typename Kernel_traits, extra params...> __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params); the caller supplies the { body } right after the macro call
|
|
|
+#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
|
|
|
+template<typename Kernel_traits, __VA_ARGS__> \
|
|
|
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
|
|
|
+
|
|
|
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
|
|
|
+    #if !defined(ARCH_SUPPORTS_FLASH)  // unsupported target: only report the build error
|
|
|
+        FLASH_UNSUPPORTED_ARCH
|
|
|
+    #else  // supported target: check the flag combination, then delegate to flash::compute_attn
|
|
|
+        static_assert(!(Is_causal && Is_local));  // Is_causal and Is_local must not both be set
|
|
|
+        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
|
|
+    #endif  // !defined(ARCH_SUPPORTS_FLASH)
|
|
|
}
|
|
|
|
|
|
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
|
|
|
-__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
|
|
|
- flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
|
|
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
|
|
|
+    #if !defined(ARCH_SUPPORTS_FLASH)  // unsupported target: only report the build error
|
|
|
+        FLASH_UNSUPPORTED_ARCH
|
|
|
+    #else  // supported target: delegate to flash::compute_attn_splitkv with the same template flags
|
|
|
+        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
|
|
+    #endif  // !defined(ARCH_SUPPORTS_FLASH)
|
|
|
}
|
|
|
|
|
|
-template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
|
|
|
-__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
|
|
|
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {  // Requires Log_max_splits >= 1, then forwards to flash::combine_attn_seqk_parallel. NOTE(review): unlike flash_fwd_kernel / flash_fwd_splitkv_kernel, this body is not wrapped in ARCH_SUPPORTS_FLASH - confirm that is intentional
|
|
|
    static_assert(Log_max_splits >= 1);
|
|
|
    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
|
|
|
}
|