|
@@ -154,7 +154,10 @@ struct CollectiveMainloopFwd {
|
|
|
typename Seqlen_traits::LayoutT layout_K;
|
|
|
Element const* ptr_V;
|
|
|
typename Seqlen_traits::LayoutT layout_V;
|
|
|
- float const softmax_scale_log2;
|
|
|
+ float const softmax_scale_log2;
|
|
|
+ float const* descale_q_ptr;
|
|
|
+ float const* descale_k_ptr;
|
|
|
+ float const* descale_v_ptr;
|
|
|
};
|
|
|
|
|
|
// Device side kernel params
|
|
@@ -166,7 +169,10 @@ struct CollectiveMainloopFwd {
|
|
|
TMA_Q tma_load_Q;
|
|
|
TMA_K tma_load_K;
|
|
|
TMA_V tma_load_V;
|
|
|
- float const softmax_scale_log2;
|
|
|
+ float const softmax_scale_log2;
|
|
|
+ float const* descale_q_ptr;
|
|
|
+ float const* descale_k_ptr;
|
|
|
+ float const* descale_v_ptr;
|
|
|
};
|
|
|
|
|
|
|
|
@@ -196,7 +202,8 @@ struct CollectiveMainloopFwd {
|
|
|
return {args.layout_Q, args.layout_K, args.layout_V,
|
|
|
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
|
|
|
tma_load_Q, tma_load_K, tma_load_V,
|
|
|
- args.softmax_scale_log2};
|
|
|
+ args.softmax_scale_log2,
|
|
|
+ args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr};
|
|
|
}
|
|
|
|
|
|
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
|
@@ -367,6 +374,7 @@ struct CollectiveMainloopFwd {
|
|
|
flatten(sVt_divide(_, i, j, stage)));
|
|
|
}
|
|
|
}
|
|
|
+ cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
|
|
|
};
|
|
|
|
|
|
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
|
|
@@ -720,7 +728,7 @@ struct CollectiveMainloopFwd {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ softmax.template online_softmax</*Is_first=*/true>(tSrS);
|
|
|
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
|
|
|
Tensor scores_scale = make_fragment_like(softmax.row_max);
|
|
|
clear(scores_scale);
|
|
@@ -747,8 +755,8 @@ struct CollectiveMainloopFwd {
|
|
|
tSrS(i) = -INFINITY;
|
|
|
}
|
|
|
}
|
|
|
- cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
- softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
|
|
|
+ softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
|
|
|
warpgroup_wait<0>();
|
|
|
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
|
|
++smem_pipe_read_k;
|
|
@@ -769,8 +777,8 @@ struct CollectiveMainloopFwd {
|
|
|
warpgroup_wait<1>();
|
|
|
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
|
|
// auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
|
|
|
- cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
- softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
|
|
|
+ softmax.template online_softmax</*Is_first=*/false>(tSrS);
|
|
|
warpgroup_wait<0>();
|
|
|
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
|
|
++smem_pipe_read_k;
|
|
@@ -783,7 +791,7 @@ struct CollectiveMainloopFwd {
|
|
|
softmax.rescale_o(tOrO, scores_scale);
|
|
|
consumer_wait(pipeline_v, smem_pipe_read_v);
|
|
|
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
|
|
- cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
+ cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS), scores_scale);
|
|
|
warpgroup_wait<0>();
|
|
|
pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang
|
|
|
++smem_pipe_read_v;
|
|
@@ -881,7 +889,7 @@ struct CollectiveMainloopFwd {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ softmax.template online_softmax</*Is_first=*/true>(tSrS);
|
|
|
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
|
|
|
permute_regs_A_to_C(tOrP);
|
|
|
|
|
@@ -915,18 +923,18 @@ struct CollectiveMainloopFwd {
|
|
|
|
|
|
warp_scheduler_barrier_arrive();
|
|
|
pipeline_k.consumer_release(smem_pipe_read);
|
|
|
+ if constexpr(Delay_V_release) {
|
|
|
+ pipeline_vt.consumer_release(smem_pipe_release);
|
|
|
+ ++smem_pipe_release;
|
|
|
+ }
|
|
|
consumer_wait(pipeline_vt, smem_pipe_read);
|
|
|
|
|
|
- cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
+ cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
|
|
|
softmax.rescale_o(tOrO, scores_scale);
|
|
|
- softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
|
|
|
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
|
|
|
permute_regs_A_to_C(tOrP);
|
|
|
-
|
|
|
- if constexpr(Delay_V_release) {
|
|
|
- pipeline_vt.consumer_release(smem_pipe_release);
|
|
|
- ++smem_pipe_release;
|
|
|
- }
|
|
|
+
|
|
|
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
|
|
|
if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
|
|
|
++smem_pipe_read;
|
|
@@ -946,9 +954,9 @@ struct CollectiveMainloopFwd {
|
|
|
if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
|
|
|
else { consumer_wait(pipeline_vt, smem_pipe_read); }
|
|
|
|
|
|
- cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
+ cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
|
|
|
softmax.rescale_o(tOrO, scores_scale);
|
|
|
- softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ softmax.template online_softmax</*Is_first=*/false>(tSrS);
|
|
|
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
|
|
|
permute_regs_A_to_C(tOrP);
|
|
|
|
|
@@ -965,23 +973,21 @@ struct CollectiveMainloopFwd {
|
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
|
for (; n_block >= 0; --n_block) {
|
|
|
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
|
|
- consumer_wait(pipeline_k, smem_pipe_read);
|
|
|
- pipeline_vt.consumer_release(smem_pipe_release);
|
|
|
- flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
|
|
|
- warp_scheduler_barrier_arrive();
|
|
|
- warpgroup_wait<0>();
|
|
|
- consumer_wait(pipeline_vt, smem_pipe_read);
|
|
|
+ consumer_wait(pipeline_k, smem_pipe_read);
|
|
|
+ flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
|
|
|
+ warp_scheduler_barrier_arrive();
|
|
|
+ pipeline_k.consumer_release(smem_pipe_read);
|
|
|
+ pipeline_vt.consumer_release(smem_pipe_release);
|
|
|
|
|
|
- cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
+ cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
|
|
|
softmax.rescale_o(tOrO, scores_scale);
|
|
|
- softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ softmax.template online_softmax</*Is_first=*/false>(tSrS);
|
|
|
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
|
|
|
permute_regs_A_to_C(tOrP);
|
|
|
-
|
|
|
- pipeline_k.consumer_release(smem_pipe_read);
|
|
|
- flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
|
|
|
+
|
|
|
+ consumer_wait(pipeline_vt, smem_pipe_read);
|
|
|
+ flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
|
|
|
warp_scheduler_barrier_sync();
|
|
|
- warpgroup_wait<0>();
|
|
|
++smem_pipe_read;
|
|
|
++smem_pipe_release;
|
|
|
}
|
|
@@ -999,9 +1005,9 @@ struct CollectiveMainloopFwd {
|
|
|
warp_scheduler_barrier_arrive();
|
|
|
pipeline_k.consumer_release(smem_pipe_read);
|
|
|
|
|
|
- cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
+ cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
|
|
|
softmax.rescale_o(tOrO, scores_scale);
|
|
|
- softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
|
|
+ softmax.template online_softmax</*Is_first=*/false>(tSrS);
|
|
|
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
|
|
|
permute_regs_A_to_C(tOrP);
|
|
|
|
|
@@ -1014,8 +1020,7 @@ struct CollectiveMainloopFwd {
|
|
|
if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
|
|
|
}
|
|
|
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
|
|
|
-
|
|
|
- cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
|
|
+ cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, shared_storage.descale_v), scores_scale);
|
|
|
softmax.rescale_o(tOrO, scores_scale);
|
|
|
return;
|
|
|
}
|