|
@@ -25,7 +25,7 @@ template <int kHeadDim, int kBlockM, int kBlockN, typename Element, bool Is_caus
|
|
|
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
|
|
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
|
|
|
using ElementAccum = float;
|
|
|
- using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, /*Clear_dQaccum=*/true, Varlen>;
|
|
|
+ using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90, /*Clear_dQaccum=*/true, Varlen>;
|
|
|
int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * 128, 128);
|
|
|
typename PreprocessKernel::Arguments preprocess_args {
|
|
|
static_cast<Element const*>(params.o_ptr),
|
|
@@ -130,7 +130,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
|
|
cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
|
|
|
CHECK_CUDA_KERNEL_LAUNCH();
|
|
|
|
|
|
- using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum,
|
|
|
+ using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, cutlass::arch::Sm90,
|
|
|
AttnKernel::CollectiveMainloop::kNThreadsdQ,
|
|
|
typename AttnKernel::CollectiveMainloop::SmemLayoutdQaccumTMA,
|
|
|
typename AttnKernel::CollectiveMainloop::TiledMmadQ,
|