#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include "cutlass/layout/layout.h"
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "kernel_traits.h"
#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementAccum, typename SmemLayoutLSE, typename SmemLayoutValidSplits>
struct SharedStorageLSE {
    cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>> smem_lse;
    cute::array_aligned<bool, cute::cosize_v<SmemLayoutValidSplits>> smem_valid_splits;
};

// DONT use Kernel_traits here to avoid redundant compilation.
// template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
template<typename Element, typename ElementAccum, int kHeadDim, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
__global__ void combine_attn_seqk_parallel(Params const params) {
    // using Element = typename Kernel_traits::OutputType;
    // using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = int64_t; // Kernel_traits::index_t
    constexpr int kMaxSplits = 1 << Log_max_splits;
    // constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNThreads = 128; // Kernel_traits::kNThreads;

    static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
    static_assert(kNThreads == 128, "We assume that each block has 128 threads");

    // Shared memory.
    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
    // __shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1];
    extern __shared__ char smem_[];
    using SharedStorage = SharedStorageLSE<ElementAccum, Shape<Int<kMaxSplits>, Int<kBlockM + 1>>, Shape<Int<kMaxSplits>>>;
    SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(smem_);
    Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kMaxSplits>, Int<kBlockM + 1>>{});
    Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape<Int<kMaxSplits>>{});

    // The thread and block index.
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    const index_t lse_size = params.b * params.h * params.seqlen_q;
    // if (cute::thread0()) print("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);

    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
                                   make_stride(lse_size, _1{}));

    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});

    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
    Layout flat_layout = make_layout(lse_size);
    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
    auto transposed_stride = params.seqlenq_ngroups_swapped
        ? make_stride(params.b, params.seqlen_q * params.b, 1)
        : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));

    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);

    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

    // Read the LSE values from gmem and store them in shared memory, then transpose them.
    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
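    // In the gather loop below, `row` indexes the split and `col` indexes the query row within this
    // block of kBlockM rows. With 128 threads, each iteration covers kRowsPerLoadLSE = 128 / kBlockM
    // splits for every row, so each thread ends up holding kNLsePerThread values. For example, with
    // kBlockM = 8 and kMaxSplits = 64, kRowsPerLoadLSE = 16 and kNLsePerThread = 4.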
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE(row, col) = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
    }
    __syncthreads();

    // Reduce along the kBlockM dimension to determine valid splits (store in SMEM).
    // One thread per split. We know kNThreads = 128 >= kMaxSplits.
    if (tidx < kMaxSplits) {
        bool is_valid_split = false;
        #pragma unroll
        for (int col = 0; col < kBlockM; ++col) {
            if (sLSE(tidx, col) != -INFINITY) { is_valid_split = true; }
        }
        sValidSplits(tidx) = is_valid_split;
    }
    __syncthreads();

    // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
    // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
    // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
    // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        // if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -INFINITY;
    }
    // return;

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp<float> max_op;
    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp<float> sum_op;
    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
        if (params.unpadded_lse) {
            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
            if (lse_offset < lse_size) {
                gLSE_unpadded(lse_offset) = lse_logsum;
            }
        } else {
            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
        }
    }
    // if (cute::thread0()) printf("lse_logsum = %f\n", lse_logsum);
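    // Each split's partial output was normalized with its own LSE, so the combined output for a row
    // is sum_s exp(lse_s - lse_logsum) * O_s, with lse_logsum = log(sum_s exp(lse_s)) computed above.
    // Setting lse_logsum = INFINITY for fully masked rows makes every scale exp(lse_s - INFINITY) = 0,
    // so those rows accumulate to zero instead of NaN.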
    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE(row, col) = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();

    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    constexpr int kBlockN = kNThreads / kBlockM;
    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    // if (cute::thread0()) print_tensor(cOaccum);
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
    // Load Oaccum in, then scale and accumulate to O.
    for (int split = 0; split < params.num_splits; ++split) {
        // DONT copy in Oaccum if lse(split) = -inf for all kBlockM.
        if (sValidSplits(split)) {
            flash::copy</*Is_even_MN=*/false, Is_even_K>(
                gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
            );
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOaccum); ++m) {
                int row = get<0>(tOcOaccum(0, m, 0));
                ElementAccum lse_scale = sLSE(split, row);
                if (lse_scale != 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                        #pragma unroll
                        for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                            tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                            // tOrO(i, m, k) += tOrOaccum(i, m, k);
                        }
                    }
                }
                // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); }
            }
        }
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
    // if (cute::thread0()) { print_tensor(tOrO); }

    Tensor rO = flash::convert_type<Element>(tOrO);
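    // Epilogue: each row handled by this block corresponds to a flat index into (batch, head, seqlen_q).
    // Decompose that index to address O with its batch/head/row strides, and reuse the tOpOaccum
    // predicate to skip columns past params.d when Is_even_K is false.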
    // Write to gO
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        // if (cute::thread0()) print("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
        if (idx < params.b * params.h * params.seqlen_q) {
            // print("final2\n");
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index to the rows of Q
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast
                    copy(rO(_, m, k), gO);
                    // if (cute::thread0()) { print("final\n"); print_tensor(gO); }
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k);
                }
            }
        }
    }
}

}  // namespace flash
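// Illustrative only: a minimal sketch of how a host-side launcher might dispatch this kernel,
// assuming a params struct with the fields used above (b, h, seqlen_q, num_splits, d, d_rounded,
// the *_ptr fields, and the stride fields), an FP16 output type, and an FP32 accumulator. The
// helper name run_combine_example and the fixed kBlockM / kLogMaxSplits / kHeadDim choices are
// hypothetical, not part of this header; a real launcher would pick them per problem size.
//
//   template<typename Params>
//   void run_combine_example(Params &params, cudaStream_t stream) {
//       constexpr int kBlockM = 8;         // rows of (b * h * seqlen_q) combined per CTA
//       constexpr int kLogMaxSplits = 5;   // supports up to 2^5 = 32 splits
//       constexpr int kHeadDim = 128;
//       using SharedStorage = flash::SharedStorageLSE<float,
//           cute::Shape<cute::Int<(1 << kLogMaxSplits)>, cute::Int<kBlockM + 1>>,
//           cute::Shape<cute::Int<(1 << kLogMaxSplits)>>>;
//       const int total_rows = params.b * params.h * params.seqlen_q;
//       dim3 grid((total_rows + kBlockM - 1) / kBlockM);
//       // The kernel requires exactly 128 threads and dynamic smem sized for SharedStorage.
//       flash::combine_attn_seqk_parallel<cutlass::half_t, float, kHeadDim, kBlockM, kLogMaxSplits,
//                                         /*Is_even_K=*/true, Params>
//           <<<grid, 128, sizeof(SharedStorage), stream>>>(params);
//   }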