123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- #pragma once
- #include <cute/tensor.hpp>
- #include <cutlass/cutlass.h>
- #include "cutlass/layout/layout.h"
- #include <cutlass/array.h>
- #include <cutlass/numeric_types.h>
- #include "kernel_traits.h"
- #include "utils.h"
- namespace flash {
- using namespace cute;
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template <class Element, class SmemShape, class SmemShapeMaxSplits>
- struct SharedStorageLSE {
- cute::array_aligned<Element, cute::size_v<SmemShape>> smem_lse;
- cute::array_aligned<bool, cute::size_v<SmemShapeMaxSplits>> smem_valid_splits;
- };
- // DONT use Kernel_traits here to avoid redundant compilation.
- // template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
- template<typename Element, typename ElementAccum, int kHeadDim, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
- __global__ void combine_attn_seqk_parallel(Params const params) {
- // using Element = typename Kernel_traits::OutputType;
- // using ElementAccum = typename Kernel_traits::ElementAccum;
- using index_t = int64_t; // Kernel_traits::index_t
- constexpr int kMaxSplits = 1 << Log_max_splits;
- // constexpr int kHeadDim = Kernel_traits::kHeadDim;
- constexpr int kNThreads = 128; //Kernel_traits::kNThreads;
- static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
- static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
- static_assert(kNThreads == 128, "We assume that each block has 128 threads");
- // Shared memory.
- // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
- //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1];
- extern __shared__ char smem_[];
- using SharedStorage = SharedStorageLSE<ElementAccum, Shape<Int<kMaxSplits>, Int<kBlockM+1>>, Shape<Int<kMaxSplits>>>;
- SharedStorage &shared_storage =
- *reinterpret_cast<SharedStorage *>(smem_);
- Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape<Int<kMaxSplits>, Int<kBlockM+1>>{});
- Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape<Int<kMaxSplits>>{});
- // The thread and block index.
- const int tidx = threadIdx.x;
- const int bidx = blockIdx.x;
- const index_t lse_size = params.b * params.h * params.seqlen_q;
- //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
- const index_t row_offset_lse = bidx * kBlockM;
- Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
- Shape<Int<kMaxSplits>, Int<kBlockM>>{},
- make_stride(lse_size, _1{}));
- // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
- // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
- Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
- Shape<Int<kBlockM>>{}, Stride<_1>{});
- // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
- Layout flat_layout = make_layout(lse_size);
- Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
- auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
- Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
- Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
- Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
- constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
- // Read the LSE values from gmem and store them in shared memory, then transpose them.
- constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
- #pragma unroll
- for (int l = 0; l < kNLsePerThread; ++l) {
- const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
- const int col = tidx % kBlockM;
- ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
- if (row < kMaxSplits) { sLSE(row,col) = lse; }
- // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
- }
- __syncthreads();
- // Reduce along the kBlockM dimension to determine valid splits (store in SMEM)
- // One thread per split. Know NumThreads = 128 >= NumMaxSplits
- if (tidx < kMaxSplits) {
- bool is_valid_split = false;
- #pragma unroll
- for (int col = 0; col < kBlockM; ++col) {
- if(sLSE(tidx,col) != -INFINITY) {
- is_valid_split = true;
- }
- }
- sValidSplits(tidx) = is_valid_split;
- }
- __syncthreads();
- // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
-
- Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
- constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
- // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
- // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
- // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
- // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
- // static_assert(kThreadsPerSplit <= 32);
- static_assert(kRowsPerLoadTranspose <= 32);
- static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
- #pragma unroll
- for (int l = 0; l < kNLsePerThread; ++l) {
- const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
- const int col = tidx / kRowsPerLoadTranspose;
- //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
- lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY;
- }
- //return;
- // Compute the logsumexp of the LSE along the split dimension.
- ElementAccum lse_max = lse_accum(0);
- #pragma unroll
- for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
- MaxOp<float> max_op;
- lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
- lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
- float lse_sum = expf(lse_accum(0) - lse_max);
- #pragma unroll
- for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
- SumOp<float> sum_op;
- lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
- // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
- // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
- ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
- // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
- if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
- if (params.unpadded_lse) {
- const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
- if (lse_offset < lse_size) {
- gLSE_unpadded(lse_offset) = lse_logsum;
- }
- } else {
- gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
- }
- }
- //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum);
- // Store the scales exp(lse - lse_logsum) in shared memory.
- #pragma unroll
- for (int l = 0; l < kNLsePerThread; ++l) {
- const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
- const int col = tidx / kRowsPerLoadTranspose;
- if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); }
- }
- __syncthreads();
- const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
- Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
- Shape<Int<kBlockM>, Int<kHeadDim>>{},
- Stride<Int<kHeadDim>, _1>{});
- constexpr int kBlockN = kNThreads / kBlockM;
- using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
- using GmemTiledCopyOaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
- GmemLayoutAtomOaccum{},
- Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
- GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
- auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
- Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
- Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
- Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
- clear(tOrO);
- // Predicates
- Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
- //if (cute::thread0()) print_tensor (cOaccum);
- // Repeat the partitioning with identity layouts
- Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
- Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
- if (!Is_even_K) {
- #pragma unroll
- for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
- }
- // Load Oaccum in then scale and accumulate to O
- for (int split = 0; split < params.num_splits; ++split) {
- // DONT copy in Oaccum if lse(split) = -inf for all kBlockM.
- if(sValidSplits(split)) {
- flash::copy</*Is_even_MN=*/false, Is_even_K>(
- gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
- );
- #pragma unroll
- for (int m = 0; m < size<1>(tOrOaccum); ++m) {
- int row = get<0>(tOcOaccum(0, m, 0));
- ElementAccum lse_scale = sLSE(split,row);
- if (lse_scale != 0.f) {
- #pragma unroll
- for (int k = 0; k < size<2>(tOrOaccum); ++k) {
- #pragma unroll
- for (int i = 0; i < size<0>(tOrOaccum); ++i) {
- tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
- //tOrO(i, m, k) += tOrOaccum(i, m, k);
- }
- }
- }
- //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); }
- }
- }
- tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
- }
- //if (cute::thread0()) { print_tensor(tOrO); }
- Tensor rO = flash::convert_type<Element>(tOrO);
- // Write to gO
- #pragma unroll
- for (int m = 0; m < size<1>(rO); ++m) {
- const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
- //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q);
- if (idx < params.b * params.h * params.seqlen_q) {
- //print ("final2\n");
- const int batch_idx = idx / (params.h * params.seqlen_q);
- const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
- // The index to the rows of Q
- const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
- auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
- + head_idx * params.o_head_stride + row * params.o_row_stride;
- #pragma unroll
- for (int k = 0; k < size<2>(rO); ++k) {
- if (Is_even_K || tOpOaccum(k)) {
- const int col = get<1>(tOcOaccum(0, m, k));
- Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
- Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
- // TODO: Should check if this is using vectorized store, but it seems pretty fast
- copy(rO(_, m, k), gO);
- //if (cute::thread0()) { print ("final\n"); print_tensor(gO); }
- // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
- // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
- }
- }
- }
- }
- }
- } // namespace flash
|