|
@@ -3,7 +3,6 @@
|
|
|
#pragma once
|
|
|
|
|
|
#include "static_switch.h"
|
|
|
-#include "fp16_switch.h"
|
|
|
#include "fmha.h"
|
|
|
#include "fmha_dgrad_kernel_1xN_loop.h"
|
|
|
|
|
@@ -62,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const boo
|
|
|
|
|
|
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
|
|
|
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
|
|
- BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
|
|
|
+ BOOL_SWITCH(is_dropout, IsDropoutConst, ({
|
|
|
auto kernel = params.is_causal
|
|
|
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
|
|
|
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
|
|
@@ -111,5 +110,5 @@ void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const boo
|
|
|
kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
|
|
|
}
|
|
|
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
|
|
- });
|
|
|
+ }));
|
|
|
}
|