|
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
|
|
|
}
|
|
|
int n_block_max = collective_mainloop.get_n_block_max(
|
|
|
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
|
|
|
- if (Is_causal && n_block_max <= 0) {
|
|
|
+ if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
|
|
|
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
|
|
|
scheduler.broadcast_next_work(work_tile_info);
|
|
|
continue;
|
|
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
|
|
|
}
|
|
|
int n_block_max = collective_mainloop.get_n_block_max(
|
|
|
mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
|
|
|
- if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
|
|
|
+ if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
|
|
|
collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
|
|
|
continue;
|
|
|
}
|