@@ -42,10 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // Will only return softmax if dropout, to reduce compilation time.
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If return_softmax, set IsEvenMNConst to false to reduce number of templates
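The pattern in this hunk (and the rest of the diff) replaces the generic BOOL_SWITCH with purpose-named variants so that individual features can be pinned to a compile-time constant instead of always instantiating both branches. The definitions are not part of this diff; here is a minimal sketch of how such variants could look, assuming opt-out flags along the lines of FLASHATTENTION_DISABLE_ALIBI (flag name and exact bodies are assumptions):

```cpp
// Generic switch: instantiates the lambda body twice, once per boolean value.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&] {                                           \
    if (COND) {                                   \
      constexpr static bool CONST_NAME = true;    \
      return __VA_ARGS__();                       \
    } else {                                      \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }                                             \
  }()

// Feature-specific variant: when the feature is compiled out, pin the flag to
// false so only one branch is ever instantiated; otherwise fall back to
// BOOL_SWITCH. EVENK_SWITCH, LOCAL_SWITCH, and DROPOUT_SWITCH would follow
// the same shape, each guarded by its own FLASHATTENTION_DISABLE_* flag.
#ifdef FLASHATTENTION_DISABLE_ALIBI
  #define ALIBI_SWITCH(COND, CONST_NAME, ...)     \
    [&] {                                         \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }()
#else
  #define ALIBI_SWITCH BOOL_SWITCH
#endif
```

Each switch that collapses to a constant halves the number of kernel instantiations downstream of it, which is presumably the point of the rename: shorter compiles and a smaller binary when a feature is known to be unused, with no behavior change in a default build.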
@@ -83,11 +83,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                 // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                 // If Is_local, set Is_causal to false
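To put numbers on the instantiation savings: run_flash_splitkv_fwd nests seven boolean switches (Is_causal, IsEvenMNConst, IsEvenKConst, Is_local, Split, Append_KV, Has_alibi), so a fully generic build could instantiate up to 2^7 = 128 kernel variants per Kernel_traits. Pinning, say, Is_local and Has_alibi to false at compile time cuts that upper bound to 2^5 = 32. The comments above prune some combinations further (e.g. forcing IsEvenMNConst to false when IsEvenKConst is false), so these are upper bounds, not exact counts.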
@@ -113,7 +113,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         // If headdim is divisible by 64, then we set kBlockM = 8, etc.
         constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
         dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
-        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 4) {
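As a concrete trace of the combine-kernel launch math in that hunk (the shapes below are made up, purely to illustrate the arithmetic):

```cpp
#include <cassert>

// Hypothetical sizes, chosen only to trace the arithmetic in the hunk above.
constexpr int kHeadDim = 128;  // divisible by 128 -> smallest row-block
constexpr int kBlockM  = kHeadDim % 128 == 0 ? 4 : (kHeadDim % 64 == 0 ? 8 : 16);
static_assert(kBlockM == 4, "headdim 128 selects kBlockM = 4");

int main() {
    // grid_combine covers all b * h * seqlen_q output rows, kBlockM rows per block.
    const int b = 2, h = 16, seqlen_q = 4096;  // assumed example shape
    const int nblocks = (b * h * seqlen_q + kBlockM - 1) / kBlockM;  // ceil-div
    assert(nblocks == 32768);                  // 131072 rows / 4 rows per block
    return 0;
}
```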
@@ -147,7 +147,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         });
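The dropout condition `params.p_dropout < 1.f` reads inverted at first glance; in this codebase params.p_dropout appears to store the keep probability (1 - dropout rate), so a value below 1 means dropout is active. DROPOUT_SWITCH is presumably the same opt-out wrapper as the sketch above, shown here under the assumption of a FLASHATTENTION_DISABLE_DROPOUT flag (name assumed, not in this diff):

```cpp
// Sketch only: pins Is_dropout to false when dropout support is compiled out,
// otherwise behaves exactly like BOOL_SWITCH.
#ifdef FLASHATTENTION_DISABLE_DROPOUT
  #define DROPOUT_SWITCH(COND, CONST_NAME, ...)   \
    [&] {                                         \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }()
#else
  #define DROPOUT_SWITCH BOOL_SWITCH
#endif
```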
@@ -157,7 +157,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
@@ -181,7 +181,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             if (is_sm8x) {
@@ -207,7 +207,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -244,7 +244,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, H100, 128 x 32 is the fastest.
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -272,7 +272,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -300,7 +300,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
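The `// 112 KB` annotation checks out if one reads the bound as 2-byte (fp16/bf16) elements times Headdim columns times one 128-row Q tile plus 64-row K and V tiles; that reading of the factors is an inference from the kernel-traits shapes, but the arithmetic itself is exact:

```cpp
// 2 B/elt * 224 cols * (128 Q rows + 64 K rows + 64 V rows) = 114688 B = 112 KiB
constexpr int Headdim = 224;
constexpr int smem_bytes = 2 * Headdim * (128 + 2 * 64);
static_assert(smem_bytes == 114688 && smem_bytes == 112 * 1024,
              "matches the 112 KB comment");
```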
@@ -331,7 +331,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, we want to run with 128 x 64 (128KB smem).
             // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
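Taken together, the substitutions are mechanical: every feature-gated BOOL_SWITCH becomes its named variant, and no control flow changes in a default build. If the build system exposes opt-out defines matching the sketches above (an assumption; e.g. compiling with -DFLASHATTENTION_DISABLE_DROPOUT -DFLASHATTENTION_DISABLE_ALIBI), Is_dropout and Has_alibi collapse to compile-time false across every head-dimension dispatch in this file, trading those features for faster builds and a smaller binary.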
|