@@ -95,7 +95,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
@@ -104,27 +104,25 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                                    // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                                    // If Is_local, set Is_causal to false
-                                    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
-                                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                                    if (smem_size >= 48 * 1024) {
-                                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                                    }
-                                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                                });
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                // If Is_local, set Is_causal to false
+                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                                if (smem_size >= 48 * 1024) {
+                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                                }
+                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                             });
                         });
                     });
@@ -159,161 +157,149 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     }
 }
 
-template<typename T, int Headdim>
+template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64; // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-                // Using block size (64 x 256) is 27% slower for seqlen=2k
-                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-        });
+        if constexpr(!Is_dropout) {
+            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+            // Using block size (64 x 256) is 27% slower for seqlen=2k
+            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            if (is_sm8x) {
-                if constexpr(!Is_causal) {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                }
-            } else {
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+        if (is_sm8x) {
+            if constexpr(!Is_causal) {
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            // These two are always slower
-            // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
-        });
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        // These two are always slower
+        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
    });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
-                if (is_sm8x) {
-                    if constexpr(!Is_causal) {
-                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                    } else {
-                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                    }
+        if constexpr(!Is_dropout) {
+            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+            if (is_sm8x) {
+                if constexpr(!Is_causal) {
+                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                 } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                 }
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // 1st ones are good for H100, A100
-                // 2nd one is good for A6000 bc we get slightly better occupancy
             } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-        });
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // 1st ones are good for H100, A100
+            // 2nd one is good for A6000 bc we get slightly better occupancy
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For A100, H100, 128 x 32 is the fastest.
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            // and 128 x 64 with 8 warps is the fastest for non-causal.
-            if (is_sm8x) {
-                if constexpr(!Is_causal) {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                }
+        // For A100, H100, 128 x 32 is the fastest.
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+        // and 128 x 64 with 8 warps is the fastest for non-causal.
+        if (is_sm8x) {
+            if constexpr(!Is_causal) {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
+        if constexpr(!Is_dropout) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     int device;
@@ -326,23 +312,21 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
-            // If we have N = 32, there are only 1024 elements to load at once, where each load
-            // is 8 elements. This means we can only use 128 threads and not 256 threads.
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
+        // If we have N = 32, there are only 1024 elements to load at once, where each load
+        // is 8 elements. This means we can only use 128 threads and not 256 threads.
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
@@ -357,18 +341,16 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For A100, we want to run with 128 x 64 (128KB smem).
-            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
-            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // 64 KB
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // 96 KB
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        // For A100, we want to run with 128 x 64 (128KB smem).
+        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // 64 KB
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // 96 KB
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
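
Note on call-site dispatch (a sketch, not part of the patch above): with Is_causal now a compile-time template parameter, the runtime check of params.is_causal has to happen at whatever calls these launchers. The minimal sketch below assumes the BOOL_SWITCH macro already used throughout this file and the run_mha_fwd_hdim128<T, Is_causal> launcher shown above; the wrapper name run_mha_fwd_hdim128_dispatch and the cutlass::half_t element type in the usage comment are illustrative assumptions, not code from this patch.

// Hypothetical wrapper: resolve the is_causal flag once at runtime, then call
// the fully templated launcher from this header.
template<typename T>
void run_mha_fwd_hdim128_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        run_mha_fwd_hdim128<T, Is_causal>(params, stream);
    });
}

// Example call site: run_mha_fwd_hdim128_dispatch<cutlass::half_t>(params, stream);

A practical consequence of hoisting Is_causal this way is that each (head dimension, Is_causal) combination can be explicitly instantiated in its own translation unit, so the heavy kernel instantiations can build in parallel instead of all expanding inside a single BOOL_SWITCH.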