|
@@ -18,18 +18,19 @@
|
|
|
#include "utils.h"
|
|
|
|
|
|
|
|
|
-template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
|
|
|
+template<typename Kernel_traits, bool Is_causal, bool Is_local, typename Seqlen_traits>
|
|
|
void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
+ static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
|
|
|
using Element = typename Kernel_traits::Element;
|
|
|
using OutputType = typename Kernel_traits::OutputType;
|
|
|
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
|
|
|
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
|
|
|
|
|
|
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
|
|
|
- using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
|
|
|
+ using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Is_local, Seqlen_traits>;
|
|
|
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
|
|
|
using Scheduler = std::conditional_t<
|
|
|
- Seqlen_traits::kUseVarSeqLen,
|
|
|
+ Seqlen_traits::kUseVarSeqLen || Is_local,
|
|
|
flash::SingleTileScheduler,
|
|
|
std::conditional_t<!Is_causal,
|
|
|
flash::StaticPersistentTileScheduler,
|
|
@@ -60,7 +61,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
params.scale_softmax_log2,
|
|
|
params.descale_q_ptr,
|
|
|
params.descale_k_ptr,
|
|
|
- params.descale_v_ptr
|
|
|
+ params.descale_v_ptr,
|
|
|
+ params.window_size_left,
|
|
|
+ params.window_size_right
|
|
|
});
|
|
|
typename CollectiveEpilogue::Params epilogue_params =
|
|
|
CollectiveEpilogue::to_underlying_arguments({
|
|
@@ -85,7 +88,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
if constexpr(cutlass::sizeof_bits_v<Element> == 8)
|
|
|
kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
|
|
|
else
|
|
|
- kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
|
|
|
+ kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Is_local, Scheduler, Seqlen_traits>;
|
|
|
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
|
|
|
// int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
|
|
|
// int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
|
|
@@ -115,11 +118,13 @@ template<typename T>
|
|
|
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
constexpr static int Headdim = 64;
|
|
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
|
|
- SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
|
|
- run_flash_fwd<
|
|
|
- Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>,
|
|
|
- Is_causal, Seqlen_traits
|
|
|
- >(params, stream);
|
|
|
+ BOOL_SWITCH(params.is_local, Is_local, [&] {
|
|
|
+ SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
|
|
+ run_flash_fwd<
|
|
|
+ Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>,
|
|
|
+ Is_causal, Is_local && !Is_causal, Seqlen_traits
|
|
|
+ >(params, stream);
|
|
|
+ });
|
|
|
});
|
|
|
});
|
|
|
}
|
|
@@ -128,13 +133,15 @@ template<typename T>
|
|
|
void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
constexpr static int Headdim = 128;
|
|
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
|
|
- SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
|
|
- // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
|
|
|
- BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
|
|
- run_flash_fwd<
|
|
|
- Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
|
|
|
- Is_causal, Seqlen_traits
|
|
|
- >(params, stream);
|
|
|
+ BOOL_SWITCH(params.is_local, Is_local, [&] {
|
|
|
+ SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
|
|
+ // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
|
|
|
+ BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
|
|
+ run_flash_fwd<
|
|
|
+ Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
|
|
|
+ Is_causal, Is_local && !Is_causal, Seqlen_traits
|
|
|
+ >(params, stream);
|
|
|
+ });
|
|
|
});
|
|
|
});
|
|
|
});
|
|
@@ -144,13 +151,15 @@ template<typename T>
|
|
|
void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
constexpr static int Headdim = 256;
|
|
|
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
|
|
- SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
|
|
- // Only use Cluster if number of tiles along seqlen_q is even
|
|
|
- BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
|
|
- run_flash_fwd<
|
|
|
- Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
|
|
|
- Is_causal, Seqlen_traits
|
|
|
- >(params, stream);
|
|
|
+ BOOL_SWITCH(params.is_local, Is_local, [&] {
|
|
|
+ SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
|
|
|
+ // Only use Cluster if number of tiles along seqlen_q is even
|
|
|
+ BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
|
|
|
+ run_flash_fwd<
|
|
|
+ Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>,
|
|
|
+ Is_causal, Is_local && !Is_causal, Seqlen_traits
|
|
|
+ >(params, stream);
|
|
|
+ });
|
|
|
});
|
|
|
});
|
|
|
});
|
|
@@ -166,11 +175,11 @@ void run_mha_fwd_hdim64_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
using Seqlen_traits = flash::FixedSeqLenTraits;
|
|
|
if(params.is_causal) {
|
|
|
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
|
|
|
- false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
|
|
|
+ false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream);
|
|
|
} else {
|
|
|
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
|
|
|
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
|
|
|
- false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
|
|
|
+ false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream);
|
|
|
});
|
|
|
}
|
|
|
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
|
@@ -195,11 +204,11 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
using Seqlen_traits = flash::FixedSeqLenTraits;
|
|
|
if(params.is_causal) {
|
|
|
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
|
|
|
- false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
|
|
|
+ false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream);
|
|
|
} else {
|
|
|
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
|
|
|
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
|
|
|
- false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
|
|
|
+ false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream);
|
|
|
});
|
|
|
}
|
|
|
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
|
@@ -224,11 +233,11 @@ void run_mha_fwd_hdim256_fp8(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|
|
using Seqlen_traits = flash::FixedSeqLenTraits;
|
|
|
if(params.is_causal) {
|
|
|
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
|
|
|
- false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
|
|
|
+ false, 1, T>, /*Is_causal=*/true, /*Is_local=*/false, Seqlen_traits>(params, stream);
|
|
|
} else {
|
|
|
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
|
|
|
run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
|
|
|
- false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
|
|
|
+ false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, /*Is_local=*/false, Seqlen_traits>(params, stream);
|
|
|
});
|
|
|
}
|
|
|
// BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|